diff --git a/models/file.go b/models/file.go index b7b3707f..bc879b2d 100644 --- a/models/file.go +++ b/models/file.go @@ -126,11 +126,21 @@ func (folder *Folder) GetChildFiles() ([]File, error) { // GetFilesByIDs 根据文件ID批量获取文件, // UID为0表示忽略用户,只根据文件ID检索 -func GetFilesByIDs(ids []uint, uid uint) ([]File, error) { - return GetFilesByIDsFromTX(DB, ids, uid) +func GetFilesByIDs(ids []uint, uid, gid uint) ([]File, error) { + uResult, err := GetFilesByIDsFromTX(DB, ids, int(uid)) + if err == nil && len(uResult) > 0 { + return uResult, nil + } + + gResult, err := GetFilesByIDsFromTX(DB, ids, -int(gid)) + if err == nil && len(gResult) > 0 { + return gResult, nil + } + + return []File{}, err } -func GetFilesByIDsFromTX(tx *gorm.DB, ids []uint, uid uint) ([]File, error) { +func GetFilesByIDsFromTX(tx *gorm.DB, ids []uint, uid int) ([]File, error) { var files []File var result *gorm.DB if uid == 0 { @@ -295,9 +305,13 @@ func GetFilesByParentIDs(ids []uint, uid uint) ([]File, error) { } // GetFilesByUploadSession 查找上传会话对应的文件 -func GetFilesByUploadSession(sessionID string, uid uint) (*File, error) { +func GetFilesByUploadSession(sessionID string, user *User) (*File, error) { file := File{} - result := DB.Where("user_id = ? and upload_session_id = ?", uid, sessionID).Find(&file) + result := DB.Where("user_id = ? and upload_session_id = ?", user.ID, sessionID).Find(&file) + if result.Error != nil { + result = DB.Where("user_id = ? and upload_session_id = ?", -int(user.GroupID), sessionID).Find(&file) + } + return &file, result.Error } diff --git a/models/share.go b/models/share.go index 750eb48e..3a13afbd 100644 --- a/models/share.go +++ b/models/share.go @@ -118,7 +118,7 @@ func (share *Share) SourceFolder() *Folder { // SourceFile 获取源文件 func (share *Share) SourceFile() *File { if share.File.ID == 0 { - files, _ := GetFilesByIDs([]uint{share.SourceID}, share.UserID) + files, _ := GetFilesByIDs([]uint{share.SourceID}, share.UserID, 0) if len(files) > 0 { share.File = files[0] } diff --git a/models/source_link.go b/models/source_link.go index 49dfea28..94c97188 100644 --- a/models/source_link.go +++ b/models/source_link.go @@ -32,7 +32,7 @@ func (s *SourceLink) Link() (string, error) { func GetSourceLinkByID(id interface{}) (*SourceLink, error) { link := &SourceLink{} result := DB.Where("id = ?", id).First(link) - files, _ := GetFilesByIDs([]uint{link.FileID}, 0) + files, _ := GetFilesByIDs([]uint{link.FileID}, 0, 0) if len(files) > 0 { link.File = files[0] } diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 78fc45fd..8bf9eb9b 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -33,7 +33,7 @@ func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, } // 查找待压缩文件 - files, err := model.GetFilesByIDs(fileIDs, fs.User.ID) + files, err := model.GetFilesByIDs(fileIDs, fs.User.ID, fs.User.GroupID) if err != nil && len(fileIDs) != 0 { return ErrDBListObjects } diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 21ad8b09..69b40729 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -53,10 +53,17 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fs } uploadInfo := file.Info() + + var id int + if parent.OwnerID < 0 { + id = parent.OwnerID + } else { + id = int(fs.User.ID) + } newFile := model.File{ Name: uploadInfo.FileName, SourceName: uploadInfo.SavePath, - UserID: int(fs.User.ID), + UserID: id, Size: uploadInfo.Size, FolderID: parent.ID, PolicyID: fs.Policy.ID, @@ -325,20 +332,20 @@ func (fs *FileSystem) ResetFileIfNotExist(ctx context.Context, path string) erro // ResetFileIfNotExist 重设当前目标文件为 id,如果当前目标为空 func (fs *FileSystem) resetFileIDIfNotExist(ctx context.Context, id uint) error { - // 找到文件 - if len(fs.FileTarget) == 0 { - file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID) - if err != nil || len(file) == 0 { + // 如果上下文限制了父目录,则进行检查 + if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok { + if parent.ID != fs.FileTarget[0].FolderID { return ErrObjectNotExist } - fs.FileTarget = []model.File{file[0]} } - // 如果上下文限制了父目录,则进行检查 - if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok { - if parent.ID != fs.FileTarget[0].FolderID { + // 找到文件 + if len(fs.FileTarget) == 0 { + file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID, fs.User.GroupID) + if err != nil || len(file) == 0 { return ErrObjectNotExist } + fs.FileTarget = []model.File{file[0]} } // 将当前存储策略重设为文件使用的 diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 1e14fa81..2452c0ee 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -262,7 +262,7 @@ func (fs *FileSystem) SetTargetDir(dirs *[]model.Folder) { // SetTargetFileByIDs 根据文件ID设置目标文件,忽略用户ID func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error { - files, err := model.GetFilesByIDs(ids, 0) + files, err := model.GetFilesByIDs(ids, 0, 0) if err != nil || len(files) == 0 { return ErrFileExisted.WithError(err) } diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go index 2534dcc1..4b3fbedc 100644 --- a/pkg/filesystem/manage.go +++ b/pkg/filesystem/manage.go @@ -26,7 +26,7 @@ func (fs *FileSystem) Rename(ctx context.Context, dir, file []uint, new string) // 如果源对象是文件 if len(file) > 0 { - fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID) + fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID, fs.User.GroupID) if err != nil || len(fileObject) == 0 { return ErrPathNotExist } @@ -257,7 +257,7 @@ func (fs *FileSystem) ListDeleteDirs(ctx context.Context, ids []uint) error { // ListDeleteFiles 根据给定的路径列出要删除的文件 func (fs *FileSystem) ListDeleteFiles(ctx context.Context, ids []uint) error { - files, err := model.GetFilesByIDs(ids, fs.User.ID) + files, err := model.GetFilesByIDs(ids, fs.User.ID, fs.User.GroupID) if err != nil { return ErrDBListObjects.WithError(err) } diff --git a/service/admin/file.go b/service/admin/file.go index 3c0cc773..c1879340 100644 --- a/service/admin/file.go +++ b/service/admin/file.go @@ -87,7 +87,7 @@ func (service *ListFolderService) List(c *gin.Context) serializer.Response { // Delete 删除文件 func (service *FileBatchService) Delete(c *gin.Context) serializer.Response { - files, err := model.GetFilesByIDs(service.ID, 0) + files, err := model.GetFilesByIDs(service.ID, 0, 0) if err != nil { return serializer.DBErr("Failed to list files for deleting", err) } @@ -141,7 +141,7 @@ func (service *FileBatchService) Delete(c *gin.Context) serializer.Response { // Get 预览文件 func (service *FileService) Get(c *gin.Context) serializer.Response { - file, err := model.GetFilesByIDs([]uint{service.ID}, 0) + file, err := model.GetFilesByIDs([]uint{service.ID}, 0, 0) if err != nil { return serializer.Err(serializer.CodeFileNotFound, "", err) } diff --git a/service/callback/upload.go b/service/callback/upload.go index 0dd7924c..0ad004db 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -120,7 +120,7 @@ func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer. uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) // 查找上传会话创建的占位文件 - file, err := model.GetFilesByUploadSession(uploadSession.Key, fs.User.ID) + file, err := model.GetFilesByUploadSession(uploadSession.Key, fs.User) if err != nil { return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) } diff --git a/service/explorer/file.go b/service/explorer/file.go index 1c9d870d..27e93eb7 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -433,7 +433,7 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se // 取得现有文件 fileID, _ := c.Get("object_id") - originFile, _ := model.GetFilesByIDs([]uint{fileID.(uint)}, fs.User.ID) + originFile, _ := model.GetFilesByIDs([]uint{fileID.(uint)}, fs.User.ID, fs.User.GroupID) if len(originFile) == 0 { return serializer.Err(serializer.CodeFileNotFound, "", nil) } @@ -485,7 +485,7 @@ func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer. } res := make([]serializer.Sources, 0, len(s.Raw().Items)) - files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID) + files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID, fs.User.GroupID) if err != nil || len(files) == 0 { return serializer.Err(serializer.CodeFileNotFound, "", err) } diff --git a/service/explorer/objects.go b/service/explorer/objects.go index 1c3c45a2..4f92e2be 100644 --- a/service/explorer/objects.go +++ b/service/explorer/objects.go @@ -381,7 +381,7 @@ func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Cont return serializer.Err(serializer.CodeNotFound, "", err) } - file, err := model.GetFilesByIDs([]uint{res}, user.ID) + file, err := model.GetFilesByIDs([]uint{res}, user.ID, user.GroupID) if err != nil { return serializer.DBErr("Failed to query file records", err) } diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 0c26c26c..221f0b7a 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -94,7 +94,7 @@ func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) s } // 查找上传会话创建的占位文件 - file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) + file, err := model.GetFilesByUploadSession(service.ID, fs.User) if err != nil { return serializer.Err(serializer.CodeUploadSessionExpired, "", err) } @@ -232,7 +232,7 @@ func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context) defer fs.Recycle() // 查找需要删除的上传会话的占位文件 - file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) + file, err := model.GetFilesByUploadSession(service.ID, fs.User) if err != nil { return serializer.Err(serializer.CodeUploadSessionExpired, "", err) } diff --git a/service/share/manage.go b/service/share/manage.go index 9daccdb4..0000a7a3 100644 --- a/service/share/manage.go +++ b/service/share/manage.go @@ -101,7 +101,7 @@ func (service *ShareCreateService) Create(c *gin.Context) serializer.Response { sourceName = folder[0].Name } } else { - file, err := model.GetFilesByIDs([]uint{sourceID}, user.ID) + file, err := model.GetFilesByIDs([]uint{sourceID}, user.ID, user.GroupID) if err != nil || len(file) == 0 { exist = false } else {