diff --git a/middleware/mock_test.go b/middleware/mock_test.go index 323d247..cc715b1 100644 --- a/middleware/mock_test.go +++ b/middleware/mock_test.go @@ -1,7 +1,36 @@ package middleware -import "testing" +import ( + "github.com/HFO4/cloudreve/pkg/util" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) func TestMockHelper(t *testing.T) { + asserts := assert.New(t) + MockHelperFunc := MockHelper() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("GET", "/test", nil) + // 写入session + { + SessionMock["test"] = "pass" + Session("test")(c) + MockHelperFunc(c) + asserts.Equal("pass", util.GetSession(c, "test").(string)) + } + + // 写入context + { + ContextMock["test"] = "pass" + MockHelperFunc(c) + test, exist := c.Get("test") + asserts.True(exist) + asserts.Equal("pass", test.(string)) + + } } diff --git a/models/file.go b/models/file.go index 160437e..4a057c9 100644 --- a/models/file.go +++ b/models/file.go @@ -127,13 +127,13 @@ func DeleteFileByIDs(ids []uint) error { return result.Error } -// GetRecursiveByPaths 根据给定的文件路径(s)递归查找文件 -func GetRecursiveByPaths(paths []string, uid uint) ([]File, error) { - files := make([]File, 0, len(paths)) - search := util.BuildRegexp(paths, "^", "/", "|") - result := DB.Where("(user_id = ? and dir REGEXP ?) or (user_id = ? and dir in (?))", uid, search, uid, paths).Find(&files) - return files, result.Error -} +//// GetRecursiveByPaths 根据给定的文件路径(s)递归查找文件 +//func GetRecursiveByPaths(paths []string, uid uint) ([]File, error) { +// files := make([]File, 0, len(paths)) +// search := util.BuildRegexp(paths, "^", "/", "|") +// result := DB.Where("(user_id = ? and dir REGEXP ?) or (user_id = ? and dir in (?))", uid, search, uid, paths).Find(&files) +// return files, result.Error +//} // GetFilesByParentIDs 根据父目录ID查找文件 func GetFilesByParentIDs(ids []uint, uid uint) ([]File, error) { diff --git a/models/file_test.go b/models/file_test.go index 94fbba9..4629ccb 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -290,3 +290,20 @@ func TestDeleteFileByIDs(t *testing.T) { asserts.NoError(err) } } + +func TestGetFilesByParentIDs(t *testing.T) { + asserts := assert.New(t) + + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 4, 5, 6). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}). + AddRow(4, "4.txt"). + AddRow(5, "5.txt"). + AddRow(6, "6.txt"), + ) + files, err := GetFilesByParentIDs([]uint{4, 5, 6}, 1) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(files, 3) +} diff --git a/models/folder.go b/models/folder.go index 225a154..c5cba49 100644 --- a/models/folder.go +++ b/models/folder.go @@ -196,7 +196,11 @@ func (folder *Folder) MoveOrCopyFolderTo(dirs []string, dstFolder *Folder, isCop // 复制 // TODO:支持多目录 origin := Folder{} - if DB.Where("position_absolute in (?) and owner_id = ?", fullDirs, folder.OwnerID).Find(&origin).Error != nil { + if DB.Where( + "position_absolute in (?) and owner_id = ?", + fullDirs, + folder.OwnerID, + ).Find(&origin).Error != nil { return 0, errors.New("找不到原始目录") } @@ -220,7 +224,11 @@ func (folder *Folder) MoveOrCopyFolderTo(dirs []string, dstFolder *Folder, isCop } else { // 移动 // 更改顶级要移动目录的父目录指向 - err = DB.Model(Folder{}).Where("position_absolute in (?) and owner_id = ?", fullDirs, folder.OwnerID). + err = DB.Model(Folder{}). + Where("position_absolute in (?) and owner_id = ?", + fullDirs, + folder.OwnerID, + ). Update(map[string]interface{}{ "parent_id": dstFolder.ID, "position": dstFolder.PositionAbsolute, @@ -272,7 +280,10 @@ func (folder *Folder) MoveOrCopyFolderTo(dirs []string, dstFolder *Folder, isCop folder.PositionAbsolute, "", 1), ) toBeMoved[innerIndex].Position = newPosition - toBeMoved[innerIndex].PositionAbsolute = path.Join(newPosition, toBeMoved[innerIndex].Name) + toBeMoved[innerIndex].PositionAbsolute = path.Join( + newPosition, + toBeMoved[innerIndex].Name, + ) toBeMoved[innerIndex].ParentID = newID toBeMoved[innerIndex].Model = gorm.Model{} if err := DB.Create(&toBeMoved[innerIndex]).Error; err != nil { diff --git a/models/folder_test.go b/models/folder_test.go index 0b3efb1..e7563a5 100644 --- a/models/folder_test.go +++ b/models/folder_test.go @@ -107,6 +107,52 @@ func TestGetRecursiveChildFolder(t *testing.T) { } } +func TestGetRecursiveChildFolderSQLite(t *testing.T) { + conf.DatabaseConfig.Type = "sqlite3" + asserts := assert.New(t) + + // 测试目录结构 + // 1 + // 2 3 + // 4 5 6 + + // 查询第一层 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, "/test"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "folder1"), + ) + // 查询第二层 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 1). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}). + AddRow(2, "folder2"). + AddRow(3, "folder3"), + ) + // 查询第三层 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 2, 3). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}). + AddRow(4, "folder4"). + AddRow(5, "folder5"). + AddRow(6, "folder6"), + ) + // 查询第四层 + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 4, 5, 6). + WillReturnRows( + sqlmock.NewRows([]string{"id", "name"}), + ) + + folders, err := GetRecursiveChildFolder([]string{"/test"}, 1, true) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Len(folders, 6) +} + func TestDeleteFolderByIDs(t *testing.T) { asserts := assert.New(t) @@ -131,3 +177,129 @@ func TestDeleteFolderByIDs(t *testing.T) { asserts.NoError(err) } } + +func TestFolder_MoveOrCopyFileTo(t *testing.T) { + asserts := assert.New(t) + // 当前目录 + folder := Folder{ + OwnerID: 1, + PositionAbsolute: "/test", + } + // 目标目录 + dstFolder := Folder{ + Model: gorm.Model{ID: 10}, + PositionAbsolute: "/dst", + } + + // 复制文件 + { + mock.ExpectQuery("SELECT(.+)"). + WithArgs( + "1.txt", + "2.txt", + 1, + "/test", + ).WillReturnRows( + sqlmock.NewRows([]string{"id", "size"}). + AddRow(1, 10). + AddRow(2, 20), + ) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + storage, err := folder.MoveOrCopyFileTo( + []string{"1.txt", "2.txt"}, + &dstFolder, + true, + ) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(uint64(30), storage) + } + + // 复制文件, 检索文件出错 + { + mock.ExpectQuery("SELECT(.+)"). + WithArgs( + "1.txt", + "2.txt", + 1, + "/test", + ).WillReturnError(errors.New("error")) + + storage, err := folder.MoveOrCopyFileTo( + []string{"1.txt", "2.txt"}, + &dstFolder, + true, + ) + asserts.Error(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(uint64(0), storage) + } + + // 复制文件,第二个文件插入出错 + { + mock.ExpectQuery("SELECT(.+)"). + WithArgs( + "1.txt", + "2.txt", + 1, + "/test", + ).WillReturnRows( + sqlmock.NewRows([]string{"id", "size"}). + AddRow(1, 10). + AddRow(2, 20), + ) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + storage, err := folder.MoveOrCopyFileTo( + []string{"1.txt", "2.txt"}, + &dstFolder, + true, + ) + asserts.Error(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(uint64(10), storage) + } + + // 移动文件 成功 + { + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)"). + WithArgs("/dst", 10, sqlmock.AnyArg(), "1.txt", "2.txt", 1, "/test"). + WillReturnResult(sqlmock.NewResult(1, 2)) + mock.ExpectCommit() + storage, err := folder.MoveOrCopyFileTo( + []string{"1.txt", "2.txt"}, + &dstFolder, + false, + ) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Equal(uint64(0), storage) + } + + // 移动文件 出错 + { + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)"). + WithArgs("/dst", 10, sqlmock.AnyArg(), "1.txt", "2.txt", 1, "/test"). + WillReturnError(errors.New("error")) + mock.ExpectRollback() + storage, err := folder.MoveOrCopyFileTo( + []string{"1.txt", "2.txt"}, + &dstFolder, + false, + ) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Equal(uint64(0), storage) + } +} diff --git a/pkg/filesystem/path.go b/pkg/filesystem/path.go index 45f3004..661375c 100644 --- a/pkg/filesystem/path.go +++ b/pkg/filesystem/path.go @@ -56,6 +56,9 @@ func (fs *FileSystem) Copy(ctx context.Context, dirs, files []string, src, dst s newUsedStorage += subFileSizes } + // 扣除容量 + fs.User.IncreaseStorageWithoutCheck(newUsedStorage) + return nil }