Add: upload controller in slave mode

pull/247/head
HFO4 5 years ago
parent 4f8558d1e8
commit 6470340104

@ -3,6 +3,7 @@ package filesystem
import ( import (
"context" "context"
"github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/filesystem/local"
"github.com/HFO4/cloudreve/pkg/filesystem/response" "github.com/HFO4/cloudreve/pkg/filesystem/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -90,12 +91,15 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
User: &model.User{}, User: &model.User{},
} }
anonymousGroup, err := model.GetGroupByID(3) // 如果是主机模式下,则为匿名文件系统分配游客用户组
if err != nil { if conf.SystemConfig.Mode == "master" {
return nil, err anonymousGroup, err := model.GetGroupByID(3)
if err != nil {
return nil, err
}
fs.User.Group = anonymousGroup
} }
fs.User.Group = anonymousGroup
return fs, nil return fs, nil
} }

@ -15,4 +15,6 @@ const (
FileModelCtx FileModelCtx
// HTTPCtx HTTP请求的上下文 // HTTPCtx HTTP请求的上下文
HTTPCtx HTTPCtx
// UploadPolicyCtx 上传策略一般为slave模式下使用
UploadPolicyCtx
) )

@ -6,6 +6,7 @@ import (
model "github.com/HFO4/cloudreve/models" model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"io/ioutil" "io/ioutil"
"strings" "strings"
@ -52,6 +53,30 @@ func HookIsFileExist(ctx context.Context, fs *FileSystem) error {
return ErrObjectNotExist return ErrObjectNotExist
} }
// HookSlaveUploadValidate Slave模式下对文件上传的一系列验证
// TODO 测试
func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem) error {
file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy)
// 验证单文件尺寸
if file.GetSize() > policy.MaxSize {
return ErrFileSizeTooBig
}
// 验证文件名
if !fs.ValidateLegalName(ctx, file.GetFileName()) {
return ErrIllegalObjectName
}
// 验证扩展名
if len(policy.AllowedExtension) > 0 && !IsInExtensionList(policy.AllowedExtension, file.GetFileName()) {
return ErrFileExtensionNotAllowed
}
return nil
}
// HookValidateFile 一系列对文件检验的集合 // HookValidateFile 一系列对文件检验的集合
func HookValidateFile(ctx context.Context, fs *FileSystem) error { func HookValidateFile(ctx context.Context, fs *FileSystem) error {
file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader) file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)

@ -6,8 +6,10 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/HFO4/cloudreve/models" model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/HFO4/cloudreve/pkg/filesystem/local"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock" testMock "github.com/stretchr/testify/mock"
@ -477,3 +479,64 @@ func TestGenericAfterUpdate(t *testing.T) {
asserts.Error(err) asserts.Error(err)
} }
} }
func TestHookSlaveUploadValidate(t *testing.T) {
asserts := assert.New(t)
conf.SystemConfig.Mode = "slave"
fs, err := NewAnonymousFileSystem()
conf.SystemConfig.Mode = "master"
asserts.NoError(err)
// 正常
{
policy := serializer.UploadPolicy{
SavePath: "",
MaxSize: 10,
AllowedExtension: nil,
}
file := local.FileStream{Name: "1.txt", Size: 10}
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
asserts.NoError(HookSlaveUploadValidate(ctx, fs))
}
// 尺寸太大
{
policy := serializer.UploadPolicy{
SavePath: "",
MaxSize: 10,
AllowedExtension: nil,
}
file := local.FileStream{Name: "1.txt", Size: 11}
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
asserts.Equal(ErrFileSizeTooBig, HookSlaveUploadValidate(ctx, fs))
}
// 文件名非法
{
policy := serializer.UploadPolicy{
SavePath: "",
MaxSize: 10,
AllowedExtension: nil,
}
file := local.FileStream{Name: "/1.txt", Size: 10}
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
asserts.Equal(ErrIllegalObjectName, HookSlaveUploadValidate(ctx, fs))
}
// 扩展名非法
{
policy := serializer.UploadPolicy{
SavePath: "",
MaxSize: 10,
AllowedExtension: []string{"jpg"},
}
file := local.FileStream{Name: "1.txt", Size: 10}
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
asserts.Equal(ErrFileExtensionNotAllowed, HookSlaveUploadValidate(ctx, fs))
}
}

@ -24,14 +24,15 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
return err return err
} }
// 生成文件名和路径, 如果是更新操作就从原始文件获取 // 生成文件名和路径,
var savePath string var savePath string
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) // 如果是更新操作就从上下文中获取
if ok { if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
savePath = originFile.SourceName savePath = originFile.SourceName
} else { } else {
savePath = fs.GenerateSavePath(ctx, file) savePath = fs.GenerateSavePath(ctx, file)
} }
ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath)
// 处理客户端未完成上传时,关闭连接 // 处理客户端未完成上传时,关闭连接
go fs.CancelUpload(ctx, savePath, file) go fs.CancelUpload(ctx, savePath, file)
@ -43,7 +44,6 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
} }
// 上传完成后的钩子 // 上传完成后的钩子
ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath)
err = fs.Trigger(ctx, fs.AfterUpload) err = fs.Trigger(ctx, fs.AfterUpload)
if err != nil { if err != nil {
@ -57,21 +57,42 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
return err return err
} }
util.Log().Info("新文件PUT:%s , 大小:%d, 上传者:%s", file.GetFileName(), file.GetSize(), fs.User.Nick) util.Log().Info(
"新文件PUT:%s , 大小:%d, 上传者:%s",
file.GetFileName(),
file.GetSize(),
fs.User.Nick,
)
return nil return nil
} }
// GenerateSavePath 生成要存放文件的路径 // GenerateSavePath 生成要存放文件的路径
// TODO 完善测试
func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string { func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string {
if fs.User.Model.ID != 0 {
return filepath.Join(
fs.User.Policy.GeneratePath(
fs.User.Model.ID,
file.GetVirtualPath(),
),
fs.User.Policy.GenerateFileName(
fs.User.Model.ID,
file.GetFileName(),
),
)
}
// 匿名文件系统使用空上传策略生成路径
nilPolicy := model.Policy{}
return filepath.Join( return filepath.Join(
fs.User.Policy.GeneratePath( nilPolicy.GeneratePath(
fs.User.Model.ID, 0,
file.GetVirtualPath(), "",
), ),
fs.User.Policy.GenerateFileName( nilPolicy.GenerateFileName(
fs.User.Model.ID, 0,
file.GetFileName(), "",
), ),
) )
} }

@ -1,10 +1,33 @@
package serializer package serializer
import (
"encoding/base64"
"encoding/json"
)
// UploadPolicy slave模式下传递的上传策略 // UploadPolicy slave模式下传递的上传策略
type UploadPolicy struct { type UploadPolicy struct {
SavePath string `json:"save_path"` SavePath string `json:"save_path"`
MaxSize uint64 `json:"save_path"` MaxSize uint64 `json:"max_size"`
AllowedExtension []string `json:"allowed_extension"` AllowedExtension []string `json:"allowed_extension"`
CallbackURL string `json:"callback_url"` CallbackURL string `json:"callback_url"`
CallbackKey string `json:"callback_key"` CallbackKey string `json:"callback_key"`
} }
// DecodeUploadPolicy 反序列化Header中携带的上传策略
// TODO 测试
func DecodeUploadPolicy(raw string) (*UploadPolicy, error) {
var res UploadPolicy
rawJSON, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
err = json.Unmarshal(rawJSON, &res)
if err != nil {
return nil, err
}
return &res, err
}

@ -0,0 +1,55 @@
package serializer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestDecodeUploadPolicy(t *testing.T) {
asserts := assert.New(t)
testCases := []struct {
input string
expectError bool
expectNil bool
expectRes *UploadPolicy
}{
{
"错误的base64字符",
true,
true,
&UploadPolicy{},
},
{
"6ZSZ6K+v55qESlNPTuWtl+espg==",
true,
true,
&UploadPolicy{},
},
{
"e30=",
false,
false,
&UploadPolicy{},
},
{
"eyJjYWxsYmFja19rZXkiOiJ0ZXN0In0=",
false,
false,
&UploadPolicy{CallbackKey: "test"},
},
}
for _, testCase := range testCases {
res, err := DecodeUploadPolicy(testCase.input)
if testCase.expectError {
asserts.Error(err)
}
if testCase.expectNil {
asserts.Nil(res)
}
if !testCase.expectNil {
asserts.Equal(testCase.expectRes, res)
}
}
}

@ -1,12 +1,75 @@
package controllers package controllers
import ( import (
"context"
"github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/local"
"github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/serializer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/url"
"strconv"
) )
// SlaveUpload 从机文件上传 // SlaveUpload 从机文件上传
func SlaveUpload(c *gin.Context) { func SlaveUpload(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
defer cancel()
// 创建匿名文件系统
fs, err := filesystem.NewAnonymousFileSystem()
if err != nil {
c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err))
return
}
// 从请求中取得上传策略
uploadPolicyRaw := c.GetHeader("X-Policy")
if uploadPolicyRaw == "" {
c.JSON(200, serializer.ParamErr("未指定上传策略", nil))
}
// 解析上传策略
uploadPolicy, err := serializer.DecodeUploadPolicy(uploadPolicyRaw)
if err != nil {
c.JSON(200, serializer.ParamErr("上传策略格式有误", err))
}
ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, uploadPolicy)
// 取得文件大小
fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64)
if err != nil {
c.JSON(200, ErrorResponse(err))
return
}
// 解码文件名和路径
fileName, err := url.QueryUnescape(c.Request.Header.Get("X-FileName"))
if err != nil {
c.JSON(200, ErrorResponse(err))
return
}
fileData := local.FileStream{
MIMEType: c.Request.Header.Get("Content-Type"),
File: c.Request.Body,
Name: fileName,
Size: fileSize,
}
// 给文件系统分配钩子
fs.Use("BeforeUpload", filesystem.HookSlaveUploadValidate)
fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile)
fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
// 执行上传
err = fs.Upload(ctx, fileData)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err))
return
}
c.JSON(200, serializer.Response{ c.JSON(200, serializer.Response{
Code: 0, Code: 0,

Loading…
Cancel
Save