diff --git a/models/file.go b/models/file.go index 9725a38..e78f6d0 100644 --- a/models/file.go +++ b/models/file.go @@ -3,6 +3,7 @@ package model import ( "encoding/gob" "encoding/json" + "errors" "path" "time" @@ -200,10 +201,35 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) { } -// DeleteFileByIDs 根据给定ID批量删除文件记录 -func DeleteFileByIDs(ids []uint) error { - result := DB.Where("id in (?)", ids).Unscoped().Delete(&File{}) - return result.Error +// DeleteFiles 批量删除文件记录并归还容量 +func DeleteFiles(files []*File, uid uint) error { + tx := DB.Begin() + user := &User{} + user.ID = uid + var size uint64 + for _, file := range files { + if file.UserID != uid { + tx.Rollback() + return errors.New("User id not consistent") + } + + result := tx.Unscoped().Delete(file) + if result.RowsAffected != 0 { + size += file.Size + } + + if result.Error != nil { + tx.Rollback() + return result.Error + } + } + + if err := user.ChangeStorage(tx, "-", size); err != nil { + tx.Rollback() + return err + } + + return tx.Commit().Error } // GetFilesByParentIDs 根据父目录ID查找文件 @@ -232,7 +258,29 @@ func (file *File) UpdatePicInfo(value string) error { // UpdateSize 更新文件的大小信息 func (file *File) UpdateSize(value uint64) error { - return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value).Error + tx := DB.Begin() + var sizeDelta uint64 + operator := "+" + user := User{} + user.ID = file.UserID + if value > file.Size { + sizeDelta = value - file.Size + } else { + operator = "-" + sizeDelta = file.Size - value + } + + if res := tx.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value); res.Error != nil { + tx.Rollback() + return res.Error + } + + if err := user.ChangeStorage(tx, operator, sizeDelta); err != nil { + tx.Rollback() + return err + } + + return tx.Commit().Error } // UpdateSourceName 更新文件的源文件名 diff --git a/models/user.go b/models/user.go index 8045d27..b311167 100644 --- a/models/user.go +++ b/models/user.go @@ -89,6 +89,11 @@ func (user *User) IncreaseStorage(size uint64) bool { return false } +// ChangeStorage 更新用户容量 +func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error { + return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error +} + // IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量 func (user *User) IncreaseStorageWithoutCheck(size uint64) { if size == 0 { diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index cf30d2c..826e101 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -133,19 +133,15 @@ func HookValidateCapacityWithoutIncrease(ctx context.Context, fs *FileSystem, fi return nil } -// HookChangeCapacity 根据原有文件和新文件的大小更新用户容量 -func HookChangeCapacity(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { +// HookValidateCapacityDiff 根据原有文件和新文件的大小验证用户容量 +func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { originFile := ctx.Value(fsctx.FileModelCtx).(model.File) newFileSize := newFile.Info().Size if newFileSize > originFile.Size { - if !fs.ValidateCapacity(ctx, newFileSize-originFile.Size) { - return ErrInsufficientCapacity - } - return nil + return HookValidateCapacityWithoutIncrease(ctx, fs, newFile) } - fs.User.DeductionStorage(originFile.Size - newFileSize) return nil } diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go index b6e66ad..ddf83e6 100644 --- a/pkg/filesystem/manage.go +++ b/pkg/filesystem/manage.go @@ -122,15 +122,12 @@ func (fs *FileSystem) Move(ctx context.Context, dirs, files []uint, src, dst str // Delete 递归删除对象, force 为 true 时强制删除文件记录,忽略物理删除是否成功 func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force bool) error { - // 已删除的总容量,map用于去重 - var deletedStorage = make(map[uint]uint64) - var totalStorage = make(map[uint]uint64) // 已删除的文件ID - var deletedFileIDs = make([]uint, 0, len(fs.FileTarget)) + var deletedFiles = make([]*model.File, 0, len(fs.FileTarget)) // 删除失败的文件的父目录ID // 所有文件的ID - var allFileIDs = make([]uint, 0, len(fs.FileTarget)) + var allFiles = make([]*model.File, 0, len(fs.FileTarget)) // 列出要删除的目录 if len(dirs) > 0 { @@ -164,39 +161,35 @@ func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force bool for i := 0; i < len(fs.FileTarget); i++ { if !util.ContainsString(failed[fs.FileTarget[i].PolicyID], fs.FileTarget[i].SourceName) { // 已成功删除的文件 - deletedFileIDs = append(deletedFileIDs, fs.FileTarget[i].ID) - deletedStorage[fs.FileTarget[i].ID] = fs.FileTarget[i].Size + deletedFiles = append(deletedFiles, &fs.FileTarget[i]) } + // 全部文件 - totalStorage[fs.FileTarget[i].ID] = fs.FileTarget[i].Size - allFileIDs = append(allFileIDs, fs.FileTarget[i].ID) + allFiles = append(allFiles, &fs.FileTarget[i]) } // 如果强制删除,则将全部文件视为删除成功 if force { - deletedFileIDs = allFileIDs - deletedStorage = totalStorage + deletedFiles = allFiles } // 删除文件记录 - err = model.DeleteFileByIDs(deletedFileIDs) + err = model.DeleteFiles(deletedFiles, fs.User.ID) if err != nil { return ErrDBDeleteObjects.WithError(err) } // 删除文件记录对应的分享记录 // TODO 先取消分享再删除文件 - model.DeleteShareBySourceIDs(deletedFileIDs, false) - - // 归还容量 - var total uint64 - for _, value := range deletedStorage { - total += value + deletedFileIDs := make([]uint, len(deletedFiles)) + for k, file := range deletedFiles { + deletedFileIDs[k] = file.ID } - fs.User.DeductionStorage(total) + + model.DeleteShareBySourceIDs(deletedFileIDs, false) // 如果文件全部删除成功,继续删除目录 - if len(deletedFileIDs) == len(allFileIDs) { + if len(deletedFiles) == len(allFiles) { var allFolderIDs = make([]uint, 0, len(fs.DirTarget)) for _, value := range fs.DirTarget { allFolderIDs = append(allFolderIDs, value.ID) @@ -210,7 +203,7 @@ func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force bool model.DeleteShareBySourceIDs(allFolderIDs, true) } - if notDeleted := len(fs.FileTarget) - len(deletedFileIDs); notDeleted > 0 { + if notDeleted := len(fs.FileTarget) - len(deletedFiles); notDeleted > 0 { return serializer.NewError( serializer.CodeNotFullySuccess, fmt.Sprintf("有 %d 个文件未能成功删除", notDeleted), diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 10cd822..8d6ff56 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -349,15 +349,13 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst fs.Use("BeforeUpload", filesystem.HookResetPolicy) fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("BeforeUpload", filesystem.HookChangeCapacity) + fs.Use("BeforeUpload", filesystem.HookValidateCapacityDiff) fs.Use("AfterUploadCanceled", filesystem.HookCleanFileContent) fs.Use("AfterUploadCanceled", filesystem.HookClearFileSize) - fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) fs.Use("AfterUpload", filesystem.GenericAfterUpdate) fs.Use("AfterValidateFailed", filesystem.HookCleanFileContent) fs.Use("AfterValidateFailed", filesystem.HookClearFileSize) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) ctx = context.WithValue(ctx, fsctx.FileModelCtx, *originFile) } else { // 给文件系统分配钩子 diff --git a/service/explorer/file.go b/service/explorer/file.go index fffd2d9..622257b 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -405,14 +405,12 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se // 给文件系统分配钩子 fs.Use("BeforeUpload", filesystem.HookResetPolicy) fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("BeforeUpload", filesystem.HookChangeCapacity) + fs.Use("BeforeUpload", filesystem.HookValidateCapacityDiff) fs.Use("AfterUploadCanceled", filesystem.HookCleanFileContent) fs.Use("AfterUploadCanceled", filesystem.HookClearFileSize) - fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) fs.Use("AfterUpload", filesystem.GenericAfterUpdate) fs.Use("AfterValidateFailed", filesystem.HookCleanFileContent) fs.Use("AfterValidateFailed", filesystem.HookClearFileSize) - fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) // 执行上传 uploadCtx = context.WithValue(uploadCtx, fsctx.FileModelCtx, originFile[0])