diff --git a/middleware/share.go b/middleware/share.go new file mode 100644 index 0000000..2154bf0 --- /dev/null +++ b/middleware/share.go @@ -0,0 +1,96 @@ +package middleware + +import ( + "fmt" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/gin-gonic/gin" +) + +// ShareAvailable 检查分享是否可用 +func ShareAvailable() gin.HandlerFunc { + return func(c *gin.Context) { + var user *model.User + if userCtx, ok := c.Get("user"); ok { + user = userCtx.(*model.User) + } else { + user = model.NewAnonymousUser() + } + + share := model.GetShareByHashID(c.Param("id")) + + if share == nil || !share.IsAvailable() { + c.JSON(200, serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil)) + c.Abort() + return + } + + c.Set("user", user) + c.Set("share", share) + c.Next() + } +} + +// ShareCanPreview 检查分享是否可被预览 +func ShareCanPreview() gin.HandlerFunc { + return func(c *gin.Context) { + if share, ok := c.Get("share"); ok { + if share.(*model.Share).PreviewEnabled { + c.Next() + return + } + c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览", + nil)) + c.Abort() + return + } + c.Abort() + } +} + +// BeforeShareDownload 分享被下载前的检查 +func BeforeShareDownload() gin.HandlerFunc { + return func(c *gin.Context) { + if shareCtx, ok := c.Get("share"); ok { + if userCtx, ok := c.Get("user"); ok { + share := shareCtx.(*model.Share) + user := userCtx.(*model.User) + + // 检查用户是否可以下载此分享的文件 + err := share.CanBeDownloadBy(user) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(), + nil)) + c.Abort() + return + } + + // 分享是否已解锁 + if share.Password != "" { + sessionKey := fmt.Sprintf("share_unlock_%d", share.ID) + unlocked := util.GetSession(c, sessionKey) != nil + if !unlocked { + c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, + "无权访问此分享", nil)) + c.Abort() + return + } + } + + // 对积分、下载次数进行更新 + err = share.DownloadBy(user, c) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(), + nil)) + c.Abort() + return + } + + c.Next() + return + } + } + c.Abort() + } +} diff --git a/models/folder.go b/models/folder.go index efc3dd0..96cf516 100644 --- a/models/folder.go +++ b/models/folder.go @@ -13,7 +13,7 @@ type Folder struct { // 表字段 gorm.Model Name string `gorm:"unique_index:idx_only_one_name"` - ParentID uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` + ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` OwnerID uint `gorm:"index:owner_id"` // 数据库忽略字段 @@ -192,7 +192,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6 // 顶级目录直接指向新的目的目录 if folder.ID == folderID { newID = dstFolder.ID - } else if IDCache, ok := newIDCache[folder.ParentID]; ok { + } else if IDCache, ok := newIDCache[*folder.ParentID]; ok { newID = IDCache } else { util.Log().Warning("无法取得新的父目录:%d", folder.ParentID) @@ -202,7 +202,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6 // 插入新的目录记录 oldID := folder.ID folder.Model = gorm.Model{} - folder.ParentID = newID + folder.ParentID = &newID if err = DB.Create(&folder).Error; err != nil { return size, err } @@ -262,6 +262,40 @@ func (folder *Folder) Rename(new string) error { return nil } +// CopyChildFrom 将给定文件和拷贝至自身,并更改所有者ID +func (folder *Folder) CopyChildFrom(folders []Folder, files []File) error { + // 开启事务 + tx := DB.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + // 记录文件父目录对应复制的新目录ID + var newParent = make(map[uint]uint, len(folders)) + + // TODO 复制目录结构 + + // 复制子文件 + for _, file := range files { + file.ID = 0 + file.UserID = folder.OwnerID + if newParentID, ok := newParent[file.FolderID]; ok { + file.FolderID = newParentID + } else { + file.FolderID = folder.ID + } + if err := tx.Create(&file).Error; err != nil { + tx.Rollback() + return err + } + } + + return tx.Commit().Error + +} + /* 实现 FileInfo.FileInfo 接口 TODO 测试 diff --git a/models/folder_test.go b/models/folder_test.go index d41a8ba..1996c6d 100644 --- a/models/folder_test.go +++ b/models/folder_test.go @@ -515,7 +515,6 @@ func TestFolder_FileInfoInterface(t *testing.T) { UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), }, Name: "test_name", - ParentID: 0, OwnerID: 0, Position: "/test", } diff --git a/models/init.go b/models/init.go index bafc977..c782745 100644 --- a/models/init.go +++ b/models/init.go @@ -47,7 +47,7 @@ func Init() { // Debug模式下,输出所有 SQL 日志 if conf.SystemConfig.Debug { - db.LogMode(false) + db.LogMode(true) } //db.SetLogger(util.Log()) diff --git a/models/user.go b/models/user.go index 1b0299c..8f0947f 100644 --- a/models/user.go +++ b/models/user.go @@ -60,7 +60,7 @@ type UserOption struct { // Root 获取用户的根目录 func (user *User) Root() (*Folder, error) { var folder Folder - err := DB.Where("parent_id = 0 AND owner_id = ?", user.ID).First(&folder).Error + err := DB.Where("parent_id is NULL AND owner_id = ?", user.ID).First(&folder).Error return &folder, err } diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go index 9fef4c3..5c3bba6 100644 --- a/pkg/filesystem/manage.go +++ b/pkg/filesystem/manage.go @@ -345,7 +345,7 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) erro // 创建目录 newFolder := model.Folder{ Name: dir, - ParentID: parent.ID, + ParentID: &parent.ID, OwnerID: fs.User.ID, } _, err := newFolder.Create() @@ -355,3 +355,34 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) erro } return nil } + +// SaveTo 将别人分享的文件转存到目标路径下 +// TODO 测试 +func (fs *FileSystem) SaveTo(ctx context.Context, path string) error { + // 获取父目录 + isExist, folder := fs.IsPathExist(path) + if !isExist { + return ErrPathNotExist + } + + // TODO 列目录 + + // 计算要复制的总大小 + var totalSize uint64 + for _, file := range fs.FileTarget { + totalSize += file.Size + } + + // 扣除用户容量 + if !fs.User.IncreaseStorage(totalSize) { + return ErrInsufficientCapacity + } + + err := folder.CopyChildFrom(fs.DirTarget, fs.FileTarget) + if err != nil { + fs.User.DeductionStorage(totalSize) + return ErrFileExisted.WithError(err) + } + + return nil +} diff --git a/routers/controllers/share.go b/routers/controllers/share.go index 25199f0..eb326a2 100644 --- a/routers/controllers/share.go +++ b/routers/controllers/share.go @@ -90,3 +90,14 @@ func GetShareDocPreview(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SaveShare 转存他人分享 +func SaveShare(c *gin.Context) { + var service share.SingleFileService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.SaveToMyFile(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 00668f7..db4c59f 100644 --- a/routers/router.go +++ b/routers/router.go @@ -168,18 +168,31 @@ func InitMasterRouter() *gin.Engine { } // 分享相关 - share := v3.Group("share") + share := v3.Group("share", middleware.ShareAvailable()) { // 获取分享 share.GET("info/:id", controllers.GetShare) // 创建文件下载会话 - share.POST("download/:id", controllers.GetShareDownload) + share.POST("download/:id", + middleware.BeforeShareDownload(), + controllers.GetShareDownload, + ) // 预览分享文件 - share.GET("preview/:id", controllers.PreviewShare) + share.GET("preview/:id", + middleware.ShareCanPreview(), + middleware.BeforeShareDownload(), + controllers.PreviewShare, + ) // 取得Office文档预览地址 - share.GET("doc/:id", controllers.GetShareDocPreview) + share.GET("doc/:id", middleware.ShareCanPreview(), + middleware.BeforeShareDownload(), + controllers.GetShareDocPreview, + ) // 获取文本文件内容 - share.GET("content/:id", controllers.PreviewShareText) + share.GET("content/:id", + middleware.BeforeShareDownload(), + controllers.PreviewShareText, + ) } // 需要登录保护的 @@ -256,6 +269,12 @@ func InitMasterRouter() *gin.Engine { { // 创建新分享 share.POST("", controllers.CreateShare) + // 转存他人分享 + share.POST("save/:id", + middleware.ShareAvailable(), + middleware.BeforeShareDownload(), + controllers.SaveShare, + ) } } diff --git a/service/share/manage.go b/service/share/manage.go index 2247ca8..65c01a2 100644 --- a/service/share/manage.go +++ b/service/share/manage.go @@ -22,7 +22,8 @@ type ShareCreateService struct { // Create 创建新分享 func (service *ShareCreateService) Create(c *gin.Context) serializer.Response { - user := currentUser(c) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) // 是否拥有权限 if !user.Group.ShareEnabled { @@ -82,13 +83,3 @@ func (service *ShareCreateService) Create(c *gin.Context) serializer.Response { } } - -func currentUser(c *gin.Context) *model.User { - var user *model.User - if userCtx, ok := c.Get("user"); ok { - user = userCtx.(*model.User) - } else { - user = model.NewAnonymousUser() - } - return user -} diff --git a/service/share/visit.go b/service/share/visit.go index 9278612..ee941a5 100644 --- a/service/share/visit.go +++ b/service/share/visit.go @@ -2,7 +2,6 @@ package share import ( "context" - "errors" "fmt" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem" @@ -25,11 +24,10 @@ type SingleFileService struct { // Get 获取分享内容 func (service *ShareGetService) Get(c *gin.Context) serializer.Response { - user := currentUser(c) - share := model.GetShareByHashID(c.Param("id")) - if share == nil || !share.IsAvailable() { - return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil) - } + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) // 是否已解锁 unlocked := true @@ -62,17 +60,10 @@ func (service *ShareGetService) Get(c *gin.Context) serializer.Response { // CreateDownloadSession 创建下载会话 func (service *SingleFileService) CreateDownloadSession(c *gin.Context) serializer.Response { - user := currentUser(c) - share := model.GetShareByHashID(c.Param("id")) - if share == nil || !share.IsAvailable() { - return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil) - } - - // 检查用户是否可以下载此分享的文件 - err := CheckBeforeGetShare(share, user, c) - if err != nil { - return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil) - } + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) // 创建文件系统 fs, err := filesystem.NewFileSystem(user) @@ -102,21 +93,8 @@ func (service *SingleFileService) CreateDownloadSession(c *gin.Context) serializ // PreviewContent 预览文件,需要登录会话, isText - 是否为文本文件,文本文件会 // 强制经由服务端中转 func (service *SingleFileService) PreviewContent(ctx context.Context, c *gin.Context, isText bool) serializer.Response { - user := currentUser(c) - share := model.GetShareByHashID(c.Param("id")) - if share == nil || !share.IsAvailable() { - return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil) - } - - if !share.PreviewEnabled { - return serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览", nil) - } - - // 检查用户是否可以下载此分享的文件 - err := CheckBeforeGetShare(share, user, c) - if err != nil { - return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil) - } + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) // 用于调下层service ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.GetSource()) @@ -129,21 +107,8 @@ func (service *SingleFileService) PreviewContent(ctx context.Context, c *gin.Con // CreateDocPreviewSession 创建Office预览会话,返回预览地址 func (service *SingleFileService) CreateDocPreviewSession(c *gin.Context) serializer.Response { - user := currentUser(c) - share := model.GetShareByHashID(c.Param("id")) - if share == nil || !share.IsAvailable() { - return serializer.Err(serializer.CodeNotFound, "分享不存在或已被取消", nil) - } - - if !share.PreviewEnabled { - return serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览", nil) - } - - // 检查用户是否可以下载此分享的文件 - err := CheckBeforeGetShare(share, user, c) - if err != nil { - return serializer.Err(serializer.CodeNoPermissionErr, err.Error(), nil) - } + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) // 用于调下层service ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, share.GetSource()) @@ -154,28 +119,35 @@ func (service *SingleFileService) CreateDocPreviewSession(c *gin.Context) serial return subService.CreateDocPreviewSession(ctx, c) } -// CheckBeforeGetShare 获取分享内容/下载前进行的一系列检查 -func CheckBeforeGetShare(share *model.Share, user *model.User, c *gin.Context) error { - // 检查用户是否可以下载此分享的文件 - err := share.CanBeDownloadBy(user) +// SaveToMyFile 将此分享转存到自己的网盘 +func (service *SingleFileService) SaveToMyFile(c *gin.Context) serializer.Response { + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) + + // 不能转存自己的文件 + if share.UserID == user.ID { + return serializer.Err(serializer.CodePolicyNotAllowed, "不能转存自己的分享", nil) + } + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(user) if err != nil { - return err + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } + defer fs.Recycle() - // 分享是否已解锁 - if share.Password != "" { - sessionKey := fmt.Sprintf("share_unlock_%d", share.ID) - unlocked := util.GetSession(c, sessionKey) != nil - if !unlocked { - return errors.New("无权访问此分享") - } + // 重设文件系统处理目标为源文件 + err = fs.SetTargetByInterface(share.GetSource()) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, "源文件不存在", err) } - // 对积分、下载次数进行更新 - err = share.DownloadBy(user, c) + err = fs.SaveTo(context.Background(), service.Path) if err != nil { - return err + return serializer.Err(serializer.CodeNotSet, err.Error(), err) } - return nil + return serializer.Response{} }