diff --git a/middleware/auth_test.go b/middleware/auth_test.go index a2e1b02..976215a 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -216,7 +216,7 @@ func TestRemoteCallbackAuth(t *testing.T) { "callback_testCallBackRemote", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 513, VirtualPath: "/", }, 0, @@ -225,7 +225,7 @@ func TestRemoteCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) c, _ := gin.CreateTestContext(rec) @@ -260,7 +260,7 @@ func TestRemoteCallbackAuth(t *testing.T) { "callback_testCallBackRemote", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 550, VirtualPath: "/", }, 0, @@ -286,7 +286,7 @@ func TestRemoteCallbackAuth(t *testing.T) { "callback_testCallBackRemote", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 514, VirtualPath: "/", }, 0, @@ -295,7 +295,7 @@ func TestRemoteCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[514]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) c, _ := gin.CreateTestContext(rec) @@ -339,7 +339,7 @@ func TestQiniuCallbackAuth(t *testing.T) { "callback_testCallBackQiniu", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 515, VirtualPath: "/", }, 0, @@ -348,7 +348,7 @@ func TestQiniuCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[515]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -371,7 +371,7 @@ func TestQiniuCallbackAuth(t *testing.T) { "callback_testCallBackQiniu", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 516, VirtualPath: "/", }, 0, @@ -380,7 +380,7 @@ func TestQiniuCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[516]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -420,7 +420,7 @@ func TestOSSCallbackAuth(t *testing.T) { "callback_testCallBackOSS", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 517, VirtualPath: "/", }, 0, @@ -429,7 +429,7 @@ func TestOSSCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[517]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -452,7 +452,7 @@ func TestOSSCallbackAuth(t *testing.T) { "callback_TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 518, VirtualPath: "/", }, 0, @@ -461,7 +461,7 @@ func TestOSSCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[518]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -506,7 +506,7 @@ func TestUpyunCallbackAuth(t *testing.T) { "callback_testCallBackUpyun", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 509, VirtualPath: "/", }, 0, @@ -515,7 +515,7 @@ func TestUpyunCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[519]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -534,7 +534,7 @@ func TestUpyunCallbackAuth(t *testing.T) { "callback_testCallBackUpyun", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 510, VirtualPath: "/", }, 0, @@ -543,7 +543,7 @@ func TestUpyunCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[520]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -563,7 +563,7 @@ func TestUpyunCallbackAuth(t *testing.T) { "callback_testCallBackUpyun", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 511, VirtualPath: "/", }, 0, @@ -572,7 +572,7 @@ func TestUpyunCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[521]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) @@ -592,7 +592,7 @@ func TestUpyunCallbackAuth(t *testing.T) { "callback_testCallBackUpyun", serializer.UploadSession{ UID: 1, - PolicyID: 2, + PolicyID: 512, VirtualPath: "/", }, 0, @@ -601,7 +601,7 @@ func TestUpyunCallbackAuth(t *testing.T) { mock.ExpectQuery("SELECT(.+)users(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[522]")) mock.ExpectQuery("SELECT(.+)policies(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) diff --git a/models/policy.go b/models/policy.go index bbcdafc..bd35f5a 100644 --- a/models/policy.go +++ b/models/policy.go @@ -77,9 +77,9 @@ func (policy *Policy) AfterFind() (err error) { // 解析存储策略设置到OptionsSerialized if policy.Options != "" { err = json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized) - if policy.OptionsSerialized.FileType == nil { - policy.OptionsSerialized.FileType = []string{} - } + } + if policy.OptionsSerialized.FileType == nil { + policy.OptionsSerialized.FileType = []string{} } return err diff --git a/models/policy_test.go b/models/policy_test.go index c253877..bf30251 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -14,6 +14,7 @@ import ( func TestGetPolicyByID(t *testing.T) { asserts := assert.New(t) + cache.Deletes([]string{"22", "23"}, "policy_") // 缓存未命中 { rows := sqlmock.NewRows([]string{"name", "type", "options"}). diff --git a/models/user_test.go b/models/user_test.go index 1dce113..b552876 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -12,7 +12,7 @@ import ( func TestGetUserByID(t *testing.T) { asserts := assert.New(t) - + cache.Deletes([]string{"1"}, "policy_") //找到用户时 userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). AddRow(1, nil, "admin@cloudreve.org", "{}", 1) @@ -104,6 +104,7 @@ func TestNewUser(t *testing.T) { func TestUser_AfterFind(t *testing.T) { asserts := assert.New(t) + cache.Deletes([]string{"1"}, "policy_") policyRows := sqlmock.NewRows([]string{"id", "name"}). AddRow(1, "默认存储策略") @@ -198,6 +199,7 @@ func TestUser_GetRemainingCapacity(t *testing.T) { func TestUser_DeductionCapacity(t *testing.T) { asserts := assert.New(t) + cache.Deletes([]string{"1"}, "policy_") userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}). AddRow(1, nil, 0, "{}", 1) mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index f1a294a..e10fc21 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -259,7 +259,7 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst string) error { rawPath := util.FormSlash(f.Name) savePath := path.Join(dst, rawPath) // 路径是否合法 - if !strings.HasPrefix(savePath, path.Clean(dst)+"/") { + if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) { return fmt.Errorf("%s: illegal file path", f.Name) } diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go index 2f28cc3..c7c8d25 100644 --- a/pkg/filesystem/archive_test.go +++ b/pkg/filesystem/archive_test.go @@ -2,12 +2,19 @@ package filesystem import ( "context" + "errors" "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/request" + "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io" + "os" + "strings" "testing" ) @@ -107,3 +114,146 @@ func TestFileSystem_Compress(t *testing.T) { } } + +type MockNopRSC string + +func (m MockNopRSC) Read(b []byte) (int, error) { + return 0, errors.New("read error") +} + +func (m MockNopRSC) Seek(n int64, offset int) (int64, error) { + return 0, errors.New("read error") +} + +func (m MockNopRSC) Close() error { + return errors.New("read error") +} + +type MockRSC struct { + rs io.ReadSeeker +} + +func (m MockRSC) Read(b []byte) (int, error) { + return m.rs.Read(b) +} + +func (m MockRSC) Seek(n int64, offset int) (int64, error) { + return m.rs.Seek(n, offset) +} + +func (m MockRSC) Close() error { + return nil +} + +func TestFileSystem_Decompress(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{Model: gorm.Model{ID: 1}}, + } + + // 压缩文件不存在 + { + // 查找根目录 + mock.ExpectQuery("SELECT(.+)folders(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "/")) + // 查找压缩文件,未找到 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + err := fs.Decompress(ctx, "/1.zip", "/") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + } + + // 无法下载压缩文件 + { + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(request.NopRSCloser{}, errors.New("error")) + fs.Handler = testHandler + err := fs.Decompress(ctx, "/1.zip", "/") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.EqualError(err, "error") + } + + // 无法创建临时压缩文件 + { + cache.Set("setting_temp_path", "/tests:", 0) + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(request.NopRSCloser{}, nil) + fs.Handler = testHandler + err := fs.Decompress(ctx, "/1.zip", "/") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Contains(err.Error(), "label syntax") + } + + // 无法写入压缩文件 + { + cache.Set("setting_temp_path", "tests", 0) + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil) + fs.Handler = testHandler + err := fs.Decompress(ctx, "/1.zip", "/") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.EqualError(err, "read error") + } + + // 无效zip文件 + { + cache.Set("setting_temp_path", "tests", 0) + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil) + fs.Handler = testHandler + err := fs.Decompress(ctx, "/1.zip", "/") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.EqualError(err, "zip: not a valid zip file") + } + + // 无法重设上传策略 + { + zipFile, _ := os.Open("tests/test.zip") + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil) + fs.Handler = testHandler + err := fs.Decompress(ctx, "/1.zip", "/") + zipFile.Close() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.True(util.IsEmpty("tests/decompress")) + asserts.EqualError(err, "未知存储策略类型") + } + + // 无法上传,容量不足 + { + cache.Set("setting_max_parallel_transfer", "1", 0) + zipFile, _ := os.Open("tests/test.zip") + fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} + fs.FileTarget[0].Policy.ID = 1 + fs.User.Policy.Type = "mock" + testHandler := new(FileHeaderMock) + testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil) + fs.Handler = testHandler + + err := fs.Decompress(ctx, "/1.zip", "/") + + zipFile.Close() + + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.True(util.IsEmpty("tests/decompress")) + testHandler.AssertExpectations(t) + } +} diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go index 2152f76..260399d 100644 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -637,15 +637,17 @@ func TestClient_SimpleUpload(t *testing.T) { client, _ := NewClient(&model.Policy{}) client.Credential.AccessToken = "AccessToken" client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + cache.Set("setting_onedrive_chunk_retries", "1", 0) - // 请求失败 + // 请求失败,并重试 { client.Credential.ExpiresIn = 0 - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) asserts.Error(err) asserts.Nil(res) } + cache.Set("setting_onedrive_chunk_retries", "0", 0) // 返回未知响应 { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() @@ -664,7 +666,7 @@ func TestClient_SimpleUpload(t *testing.T) { }, }) client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) clientMock.AssertExpectations(t) asserts.Error(err) asserts.Nil(res) @@ -688,7 +690,7 @@ func TestClient_SimpleUpload(t *testing.T) { }, }) client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) clientMock.AssertExpectations(t) asserts.NoError(err) asserts.NotNil(res) @@ -733,6 +735,36 @@ func TestClient_DeleteUploadSession(t *testing.T) { } } +func TestClient_BatchDelete(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 小于20个,失败1个 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), + }, + }) + client.Request = clientMock + res, err := client.BatchDelete(context.Background(), []string{"1", "2", "3", "1", "2"}) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Equal([]string{"2"}, res) + } +} + func TestClient_Delete(t *testing.T) { asserts := assert.New(t) client, _ := NewClient(&model.Policy{}) diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go index 6a3646e..2c5348d 100644 --- a/pkg/filesystem/file_test.go +++ b/pkg/filesystem/file_test.go @@ -131,7 +131,7 @@ func TestFileSystem_GetDownloadContent(t *testing.T) { }, Policy: model.Policy{ Model: gorm.Model{ - ID: 1, + ID: 599, }, }, }, @@ -140,23 +140,26 @@ func TestFileSystem_GetDownloadContent(t *testing.T) { asserts.NoError(err) _ = file.Close() + cache.Deletes([]string{"599"}, "policy_") mock.ExpectQuery("SELECT(.+)"). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 1, "TestFileSystem_GetDownloadContent.txt")) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) // 无限速 + cache.Deletes([]string{"599"}, "policy_") _, err = fs.GetDownloadContent(ctx, "/TestFileSystem_GetDownloadContent.txt") asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) fs.CleanTargets() // 有限速 + cache.Deletes([]string{"599"}, "policy_") mock.ExpectQuery("SELECT(.+)"). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 1, "TestFileSystem_GetDownloadContent.txt")) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) fs.User.Group.SpeedLimit = 1 @@ -346,13 +349,13 @@ func TestFileSystem_GetSource(t *testing.T) { WithArgs(2, 1). WillReturnRows( sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 35, "1.txt"), + AddRow(2, 36, "1.txt"), ) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). WillReturnRows( sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "?", true), + AddRow(36, "?", true), ) sourceURL, err := fs.GetSource(ctx, 2) @@ -375,13 +378,13 @@ func TestFileSystem_GetSource(t *testing.T) { WithArgs(2, 1). WillReturnRows( sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 35, "1.txt"), + AddRow(2, 37, "1.txt"), ) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). WillReturnRows( sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "local", false), + AddRow(37, "local", false), ) sourceURL, err := fs.GetSource(ctx, 2) @@ -403,15 +406,15 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { // 正常 { - err := cache.Deletes([]string{"siteURL"}, "setting_") - err = cache.Deletes([]string{"35"}, "policy_") - err = cache.Deletes([]string{"download_timeout"}, "setting_") + err := cache.Deletes([]string{"35"}, "policy_") + cache.Set("setting_download_timeout", "20", 0) + cache.Set("setting_siteURL", "https://cloudreve.org", 0) asserts.NoError(err) // 查找文件 mock.ExpectQuery("SELECT(.+)"). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). WillReturnRows( @@ -419,8 +422,6 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { AddRow(35, "local", true), ) // 相关设置 - mock.ExpectQuery("SELECT(.+)").WithArgs("download_timeout").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "20")) - mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://cloudreve.org")) downloadURL, err := fs.GetDownloadURL(ctx, "/1.txt", "download_timeout") asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) @@ -457,7 +458,7 @@ func TestFileSystem_GetDownloadURL(t *testing.T) { mock.ExpectQuery("SELECT(.+)"). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 1)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) // 查找上传策略 mock.ExpectQuery("SELECT(.+)"). WillReturnRows( diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 2346f6b..55e3a5e 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -625,3 +625,31 @@ func TestSlaveAfterUpload(t *testing.T) { asserts.NoError(err) } } + +func TestFileSystem_CleanHooks(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{ + User: &model.User{ + Model: gorm.Model{ID: 1}, + }, + Hooks: map[string][]Hook{ + "hook1": []Hook{}, + "hook2": []Hook{}, + "hook3": []Hook{}, + }, + } + + // 清理一个 + { + fs.CleanHooks("hook2") + asserts.Len(fs.Hooks, 2) + asserts.Contains(fs.Hooks, "hook1") + asserts.Contains(fs.Hooks, "hook3") + } + + // 清理全部 + { + fs.CleanHooks("") + asserts.Len(fs.Hooks, 0) + } +} diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go index 6980e27..64b86ad 100644 --- a/pkg/filesystem/manage_test.go +++ b/pkg/filesystem/manage_test.go @@ -149,12 +149,6 @@ func TestFileSystem_CreateDirectory(t *testing.T) { _, err := fs.CreateDirectory(ctx, "/ad/a+?") asserts.Equal(ErrIllegalObjectName, err) - // 父目录不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.Equal(ErrPathNotExist, err) - asserts.NoError(mock.ExpectationsWereMet()) - // 存在同名文件 // 根目录 mock.ExpectQuery("SELECT(.+)"). @@ -204,6 +198,12 @@ func TestFileSystem_CreateDirectory(t *testing.T) { _, err = fs.CreateDirectory(ctx, "/ad/ab") asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) + + // 父目录不存在 + mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + _, err = fs.CreateDirectory(ctx, "/ad") + asserts.Equal(ErrRootProtected, err) + asserts.NoError(mock.ExpectationsWereMet()) } func TestFileSystem_ListDeleteFiles(t *testing.T) { @@ -323,12 +323,11 @@ func TestFileSystem_Delete(t *testing.T) { AddRow(4, "1.txt", "1.txt", 2, 1), ) // 查询顶级的文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "1.txt", "1.txt", 1, 2)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "1.txt", "1.txt", 603, 2)) mock.ExpectQuery("SELECT(.+)files(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(603, "local")) // 删除文件记录 mock.ExpectBegin() mock.ExpectExec("DELETE(.+)files"). @@ -368,14 +367,13 @@ func TestFileSystem_Delete(t *testing.T) { WithArgs(1, 2, 3). WillReturnRows( sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}). - AddRow(4, "1.txt", "1.txt", 2, 1), + AddRow(4, "1.txt", "1.txt", 602, 1), ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 1, 2)) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2)) mock.ExpectQuery("SELECT(.+)files(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(602, "local")) // 删除文件记录 mock.ExpectBegin() mock.ExpectExec("DELETE(.+)"). diff --git a/pkg/filesystem/tests/test.zip b/pkg/filesystem/tests/test.zip new file mode 100644 index 0000000..316212e Binary files /dev/null and b/pkg/filesystem/tests/test.zip differ diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index 41f4fe9..9b0ea1c 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -14,9 +14,11 @@ import ( "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" "io" + "io/ioutil" "net/http" "net/http/httptest" "net/url" + "strings" "testing" ) @@ -200,3 +202,30 @@ func TestFileSystem_GetUploadToken(t *testing.T) { asserts.Error(err) } } + +func TestFileSystem_UploadFromStream(t *testing.T) { + asserts := assert.New(t) + fs := FileSystem{User: &model.User{Model: gorm.Model{ID: 1}}} + ctx := context.Background() + + err := fs.UploadFromStream(ctx, ioutil.NopCloser(strings.NewReader("123")), "/1.txt", 1) + asserts.Error(err) +} + +func TestFileSystem_UploadFromPath(t *testing.T) { + asserts := assert.New(t) + fs := FileSystem{User: &model.User{Policy: model.Policy{Type: "mock"}, Model: gorm.Model{ID: 1}}} + ctx := context.Background() + + // 文件不存在 + { + err := fs.UploadFromPath(ctx, "test/not_exist", "/") + asserts.Error(err) + } + + // 文存在,上传失败 + { + err := fs.UploadFromPath(ctx, "tests/test.zip", "/") + asserts.Error(err) + } +} diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go index e19ce42..87f6df5 100644 --- a/pkg/filesystem/validator_test.go +++ b/pkg/filesystem/validator_test.go @@ -82,6 +82,10 @@ func TestFileSystem_ValidateFileSize(t *testing.T) { asserts.True(fs.ValidateFileSize(ctx, 5)) asserts.True(fs.ValidateFileSize(ctx, 10)) asserts.False(fs.ValidateFileSize(ctx, 11)) + + // 无限制 + fs.User.Policy.MaxSize = 0 + asserts.True(fs.ValidateFileSize(ctx, 11)) } func TestFileSystem_ValidateExtension(t *testing.T) { diff --git a/pkg/util/io_test.go b/pkg/util/io_test.go index 2e095cc..755d203 100644 --- a/pkg/util/io_test.go +++ b/pkg/util/io_test.go @@ -30,3 +30,10 @@ func TestCreatNestedFile(t *testing.T) { asserts.FileExists("test/direct.txt") } } + +func TestIsEmpty(t *testing.T) { + asserts := assert.New(t) + + asserts.False(IsEmpty("")) + asserts.False(IsEmpty("not_exist")) +}