From eceee2fc76b24c26b9772a2d334f7ae9851faea5 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Mon, 30 Dec 2019 19:56:01 +0800 Subject: [PATCH] Test: remote callback auth --- middleware/auth_test.go | 144 ++++++++++++++++++++++++++++++++++++++++ models/user.go | 4 +- pkg/auth/auth.go | 9 ++- 3 files changed, 153 insertions(+), 4 deletions(-) diff --git a/middleware/auth_test.go b/middleware/auth_test.go index b2d9ef5..5361ea6 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -5,6 +5,8 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/auth" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" @@ -198,3 +200,145 @@ func TestWebDAVAuth(t *testing.T) { } } + +func TestRemoteCallbackAuth(t *testing.T) { + asserts := assert.New(t) + rec := httptest.NewRecorder() + AuthFunc := RemoteCallbackAuth() + + // 成功 + { + cache.Set( + "callback_testCallBackRemote", + serializer.UploadSession{ + UID: 1, + PolicyID: 2, + VirtualPath: "/", + }, + 0, + ) + cache.Deletes([]string{"1"}, "policy_") + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) + mock.ExpectQuery("SELECT(.+)groups(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + mock.ExpectQuery("SELECT(.+)policies(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"key", "testCallBackRemote"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) + authInstance := auth.HMACAuth{SecretKey: []byte("123")} + auth.SignRequest(authInstance, c.Request, 0) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.False(c.IsAborted()) + } + + // Callback Key 不存在 + { + + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"key", "testCallBackRemote"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) + authInstance := auth.HMACAuth{SecretKey: []byte("123")} + auth.SignRequest(authInstance, c.Request, 0) + AuthFunc(c) + asserts.True(c.IsAborted()) + } + + // 用户不存在 + { + cache.Set( + "callback_testCallBackRemote", + serializer.UploadSession{ + UID: 1, + PolicyID: 2, + VirtualPath: "/", + }, + 0, + ) + cache.Deletes([]string{"1"}, "policy_") + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"})) + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"key", "testCallBackRemote"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) + authInstance := auth.HMACAuth{SecretKey: []byte("123")} + auth.SignRequest(authInstance, c.Request, 0) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.True(c.IsAborted()) + } + + // 存储策略不一致 + { + cache.Set( + "callback_testCallBackRemote", + serializer.UploadSession{ + UID: 1, + PolicyID: 2, + VirtualPath: "/", + }, + 0, + ) + cache.Deletes([]string{"1"}, "policy_") + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) + mock.ExpectQuery("SELECT(.+)groups(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[3]")) + mock.ExpectQuery("SELECT(.+)policies(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(3, "123")) + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"key", "testCallBackRemote"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) + authInstance := auth.HMACAuth{SecretKey: []byte("123")} + auth.SignRequest(authInstance, c.Request, 0) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.True(c.IsAborted()) + } + + // 签名错误 + { + cache.Set( + "callback_testCallBackRemote", + serializer.UploadSession{ + UID: 1, + PolicyID: 2, + VirtualPath: "/", + }, + 0, + ) + cache.Deletes([]string{"1"}, "policy_") + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) + mock.ExpectQuery("SELECT(.+)groups(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[2]")) + mock.ExpectQuery("SELECT(.+)policies(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) + c, _ := gin.CreateTestContext(rec) + c.Params = []gin.Param{ + {"key", "testCallBackRemote"}, + } + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.True(c.IsAborted()) + } + + // Callback Key 为空 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote", nil) + AuthFunc(c) + asserts.True(c.IsAborted()) + } +} diff --git a/models/user.go b/models/user.go index e24ba87..ef266ed 100644 --- a/models/user.go +++ b/models/user.go @@ -177,7 +177,9 @@ func (user *User) AfterCreate(tx *gorm.DB) (err error) { // AfterFind 找到用户后的钩子 func (user *User) AfterFind() (err error) { // 解析用户设置到OptionsSerialized - err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) + if user.Options != "" { + err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) + } // 预加载存储策略 user.Policy, _ = GetPolicyByID(user.GetPolicyID()) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 1dc83b4..3aba627 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -60,9 +60,12 @@ func getSignContent(r *http.Request) (rawSignString string) { if policy, ok := r.Header["X-Policy"]; ok { rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "") } else { - body, _ := ioutil.ReadAll(r.Body) - _ = r.Body.Close() - r.Body = ioutil.NopCloser(bytes.NewReader(body)) + var body = []byte{} + if r.Body != nil { + body, _ = ioutil.ReadAll(r.Body) + _ = r.Body.Close() + r.Body = ioutil.NopCloser(bytes.NewReader(body)) + } rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body)) } return rawSignString