diff --git a/middleware/auth.go b/middleware/auth.go index 731b517..885d7a4 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -97,8 +97,8 @@ func WebDAVAuth() gin.HandlerFunc { } // 密码正确? - ok, _ = expectedUser.CheckPassword(password) - if !ok { + webdav, err := model.GetWebdavByPassword(password, expectedUser.ID) + if err != nil { c.Status(http.StatusUnauthorized) c.Abort() return @@ -112,6 +112,7 @@ func WebDAVAuth() gin.HandlerFunc { } c.Set("user", &expectedUser) + c.Set("webdav", webdav) c.Next() } } diff --git a/models/folder.go b/models/folder.go index 1366f71..46b9b55 100644 --- a/models/folder.go +++ b/models/folder.go @@ -15,6 +15,7 @@ type Folder struct { Name string `gorm:"unique_index:idx_only_one_name"` ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` OwnerID uint `gorm:"index:owner_id"` + PolicyID uint // Webdav下挂载的存储策略ID // 数据库忽略字段 Position string `gorm:"-"` diff --git a/models/migration.go b/models/migration.go index 94f8a3d..60c6ffe 100644 --- a/models/migration.go +++ b/models/migration.go @@ -30,7 +30,7 @@ func migration() { DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") } DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}, - &Task{}, &Download{}, &Tag{}) + &Task{}, &Download{}, &Tag{}, &Webdav{}) // 创建初始存储策略 addDefaultPolicy() diff --git a/models/share.go b/models/share.go index e1b5290..3fe636d 100644 --- a/models/share.go +++ b/models/share.go @@ -220,6 +220,11 @@ func (share *Share) Delete() error { return DB.Model(share).Delete(share).Error } +// DeleteShareBySourceIDs 根据原始资源类型和ID删除文件 +func DeleteShareBySourceIDs(sources []uint, isDir bool) error { + return DB.Where("source_id in (?) and is_dir = ?", sources, isDir).Delete(&Share{}).Error +} + // ListShares 列出UID下的分享 func ListShares(uid uint, page, pageSize int, order string, publicOnly bool) ([]Share, int) { var ( diff --git a/models/user.go b/models/user.go index 8583711..28e3036 100644 --- a/models/user.go +++ b/models/user.go @@ -135,17 +135,20 @@ func (user *User) GetRemainingCapacity() uint64 { } // GetPolicyID 获取用户当前的存储策略ID -func (user *User) GetPolicyID() uint { +func (user *User) GetPolicyID(prefer uint) uint { + if prefer == 0 { + prefer = user.OptionsSerialized.PreferredPolicy + } // 用户未指定时,返回可用的第一个 - if user.OptionsSerialized.PreferredPolicy == 0 { + if prefer == 0 { if len(user.Group.PolicyList) != 0 { return user.Group.PolicyList[0] } return 1 } // 用户指定时,先检查是否为可用策略列表中的值 - if util.ContainsUint(user.Group.PolicyList, user.OptionsSerialized.PreferredPolicy) { - return user.OptionsSerialized.PreferredPolicy + if util.ContainsUint(user.Group.PolicyList, prefer) { + return prefer } // 不可用时,返回第一个 if len(user.Group.PolicyList) != 0 { @@ -205,7 +208,7 @@ func (user *User) AfterFind() (err error) { } // 预加载存储策略 - user.Policy, _ = GetPolicyByID(user.GetPolicyID()) + user.Policy, _ = GetPolicyByID(user.GetPolicyID(0)) return err } diff --git a/models/user_test.go b/models/user_test.go index b552876..0da77c0 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -166,7 +166,7 @@ func TestUser_GetPolicyID(t *testing.T) { for key, testCase := range testCases { newUser.OptionsSerialized.PreferredPolicy = testCase.preferred newUser.Group.PolicyList = testCase.available - asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key) + asserts.Equal(testCase.expected, newUser.GetPolicyID(0), "测试用例 #%d 未通过", key) } } diff --git a/models/webdav.go b/models/webdav.go new file mode 100644 index 0000000..2b58a8f --- /dev/null +++ b/models/webdav.go @@ -0,0 +1,19 @@ +package model + +import "github.com/jinzhu/gorm" + +// Webdav 应用账户 +type Webdav struct { + gorm.Model + Name string // 应用名称 + Password string `gorm:"unique_index:password_only_on"` // 应用密码 + UserID uint `gorm:"unique_index:password_only_on"` // 用户ID + Root string `gorm:"type:text"` // 根目录 +} + +// GetWebdavByPassword 根据密码和用户查找Webdav应用 +func GetWebdavByPassword(password string, uid uint) (*Webdav, error) { + webdav := &Webdav{} + res := DB.Where("user_id = ? and password = ?", uid, password).First(webdav) + return webdav, res.Error +} diff --git a/pkg/filesystem/driver/cos/handller.go b/pkg/filesystem/driver/cos/handller.go index 58b0005..cfed904 100644 --- a/pkg/filesystem/driver/cos/handller.go +++ b/pkg/filesystem/driver/cos/handller.go @@ -113,6 +113,10 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err failed = append(failed, v.Key) } + if len(failed) == 0 { + return failed, nil + } + return failed, errors.New("删除失败") } diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go index af4e728..716ab43 100644 --- a/pkg/filesystem/manage.go +++ b/pkg/filesystem/manage.go @@ -190,6 +190,9 @@ func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint) error { return ErrDBDeleteObjects.WithError(err) } + // 删除文件记录对应的分享记录 + model.DeleteShareBySourceIDs(allFileIDs, false) + // 归还容量 var total uint64 for _, value := range totalStorage { @@ -207,6 +210,9 @@ func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint) error { return ErrDBDeleteObjects.WithError(err) } + // 删除目录记录对应的分享记录 + model.DeleteShareBySourceIDs(allFolderIDs, true) + if notDeleted := len(fs.FileTarget) - len(deletedFileIDs); notDeleted > 0 { return serializer.NewError( serializer.CodeNotFullySuccess, diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 9f45bcb..d2a1006 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -172,7 +172,7 @@ func (fs *FileSystem) GetUploadToken(ctx context.Context, path string, size uint serializer.UploadSession{ Key: callbackKey, UID: fs.User.ID, - PolicyID: fs.User.GetPolicyID(), + PolicyID: fs.User.GetPolicyID(0), VirtualPath: path, Name: name, Size: size, diff --git a/pkg/webdav/file.go b/pkg/webdav/file.go index a5241b0..983696e 100644 --- a/pkg/webdav/file.go +++ b/pkg/webdav/file.go @@ -48,8 +48,8 @@ func moveFiles(ctx context.Context, fs *filesystem.FileSystem, src FileInfo, dst } else { err = fs.Move( ctx, - fileIDs, folderIDs, + fileIDs, src.GetPosition(), path.Dir(dst), ) @@ -81,7 +81,7 @@ func copyFiles(ctx context.Context, fs *filesystem.FileSystem, src FileInfo, dst return http.StatusInternalServerError, err } } else { - err := fs.Copy(ctx, []uint{src.(*model.File).ID}, []uint{}, src.(*model.File).Position, dst) + err := fs.Copy(ctx, []uint{}, []uint{src.(*model.File).ID}, src.(*model.File).Position, path.Dir(dst)) if err != nil { return http.StatusInternalServerError, err } diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 4229c2f..58099bd 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -356,6 +356,17 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) ctx = context.WithValue(ctx, fsctx.FileModelCtx, *originFile) } else { + // 检查父目录指定存储策略 + if exist, folder := fs.IsPathExist(filePath); exist { + if folder.PolicyID != 0 { + // 尝试获取并重设存储策略 + if policy, err := model.GetPolicyByID(fs.User.GetPolicyID(folder.PolicyID)); err == nil { + fs.User.Policy = policy + fs.DispatchHandler() + } + } + } + // 给文件系统分配钩子 fs.Use("BeforeUpload", filesystem.HookValidateFile) fs.Use("BeforeUpload", filesystem.HookValidateCapacity) diff --git a/routers/controllers/webdav.go b/routers/controllers/webdav.go index 1d72316..4f99f4a 100644 --- a/routers/controllers/webdav.go +++ b/routers/controllers/webdav.go @@ -1,6 +1,7 @@ package controllers import ( + model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/webdav" @@ -24,5 +25,18 @@ func ServeWebDAV(c *gin.Context) { return } + if webdavCtx, ok := c.Get("webdav"); ok { + application := webdavCtx.(*model.Webdav) + + // 重定根目录 + if application.Root != "/" { + if exist, root := fs.IsPathExist(application.Root); exist { + root.Position = "" + root.Name = "/" + fs.Root = root + } + } + } + handler.ServeHTTP(c.Writer, c.Request, fs) }