Fix: storage policy should be re-dispatched according to policy id in upload session

pull/247/head
HFO4 5 years ago
parent b862ddb666
commit 68d4a86166

@ -132,11 +132,6 @@ func uploadCallbackCheck(c *gin.Context) (serializer.Response, *model.User) {
}
c.Set("user", &user)
// 检查存储策略是否一致
if user.GetPolicyID() != callbackSession.PolicyID {
return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略已变更,请重新上传", nil), nil
}
return serializer.Response{}, &user
}

@ -277,36 +277,6 @@ func TestRemoteCallbackAuth(t *testing.T) {
asserts.True(c.IsAborted())
}
// 存储策略不一致
{
cache.Set(
"callback_testCallBackRemote",
serializer.UploadSession{
UID: 1,
PolicyID: 2,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[3]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(3, "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 签名错误
{
cache.Set(

@ -105,7 +105,7 @@ func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *
if file != nil {
// 切换上传策略
fs.Policy = file.GetPolicy()
err := fs.dispatchHandler()
err := fs.DispatchHandler()
if err != nil {
util.Log().Warning("无法压缩文件%s%s", file.Name, err)
return

@ -79,7 +79,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
// 重设上传策略
fs.Policy = &model.Policy{Type: "local"}
_ = fs.dispatchHandler()
_ = fs.DispatchHandler()
// 获取文件流
rs, err := fs.Handler.Get(ctx, path)
@ -184,7 +184,7 @@ func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*m
// 切换上传策略
fs.Policy = toBeDeletedFiles[0].GetPolicy()
err := fs.dispatchHandler()
err := fs.DispatchHandler()
if err != nil {
failed[policyID] = sourceNames
continue
@ -327,7 +327,7 @@ func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error {
}
fs.Policy = fs.FileTarget[0].GetPolicy()
err := fs.dispatchHandler()
err := fs.DispatchHandler()
if err != nil {
return err
}

@ -109,7 +109,7 @@ func NewFileSystem(user *model.User) (*FileSystem, error) {
fs := getEmptyFS()
fs.User = user
// 分配存储策略适配器
err := fs.dispatchHandler()
err := fs.DispatchHandler()
// TODO 分配默认钩子
return fs, err
@ -135,9 +135,9 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
return fs, nil
}
// dispatchHandler 根据存储策略分配文件适配器
// DispatchHandler 根据存储策略分配文件适配器
// TODO 完善测试
func (fs *FileSystem) dispatchHandler() error {
func (fs *FileSystem) DispatchHandler() error {
var policyType string
var currentPolicy *model.Policy

@ -64,13 +64,13 @@ func TestDispatchHandler(t *testing.T) {
}
// 未指定,使用用户默认
err := fs.dispatchHandler()
err := fs.DispatchHandler()
asserts.NoError(err)
asserts.IsType(local.Driver{}, fs.Handler)
// 已指定,发生错误
fs.Policy = &model.Policy{Type: "unknown"}
err = fs.dispatchHandler()
err = fs.DispatchHandler()
asserts.Error(err)
}

@ -106,7 +106,7 @@ func HookResetPolicy(ctx context.Context, fs *FileSystem) error {
}
fs.Policy = originFile.GetPolicy()
return fs.dispatchHandler()
return fs.DispatchHandler()
}
// HookValidateCapacity 验证并扣除用户容量,包含数据库操作

@ -2,6 +2,7 @@ package callback
import (
"context"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/local"
@ -61,6 +62,17 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.
}
callbackSession := callbackSessionRaw.(*serializer.UploadSession)
// 重新指向上传策略
policy, err := model.GetPolicyByID(callbackSession.PolicyID)
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
fs.Policy = &policy
err = fs.DispatchHandler()
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
// 获取父目录
exist, parentFolder := fs.IsPathExist(callbackSession.VirtualPath)
if !exist {

Loading…
Cancel
Save