diff --git a/middleware/mock.go b/middleware/mock.go new file mode 100644 index 0000000..63af1ef --- /dev/null +++ b/middleware/mock.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "cloudreve/pkg/util" + "github.com/gin-gonic/gin" +) + +// SessionMock 测试时模拟Session +var SessionMock = make(map[string]interface{}) + +// ContextMock 测试时模拟Context +var ContextMock = make(map[string]interface{}) + +// MockHelper 单元测试助手中间件 +func MockHelper() gin.HandlerFunc { + return func(c *gin.Context) { + // 将SessionMock写入会话 + util.SetSession(c, SessionMock) + for key, value := range ContextMock { + c.Set(key, value) + } + c.Next() + } +} diff --git a/routers/router.go b/routers/router.go index 89a79bf..bf6bae5 100644 --- a/routers/router.go +++ b/routers/router.go @@ -11,11 +11,21 @@ import ( func InitRouter() *gin.Engine { r := gin.Default() - // 中间件 + /* + 中间件 + */ r.Use(middleware.Session(conf.SystemConfig.SessionSecret)) + + // 测试模式加加入Mock助手中间件 + if gin.Mode() == gin.TestMode { + r.Use(middleware.MockHelper()) + } + r.Use(middleware.CurrentUser()) - // 顶层路由分组 + /* + 路由 + */ v3 := r.Group("/Api/V3") { // 测试用路由 diff --git a/routers/router_test.go b/routers/router_test.go index a96f9b9..f895eeb 100644 --- a/routers/router_test.go +++ b/routers/router_test.go @@ -2,12 +2,14 @@ package routers import ( "bytes" + "cloudreve/middleware" "cloudreve/models" "cloudreve/pkg/serializer" "database/sql" "encoding/json" "errors" "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "net/http" @@ -27,6 +29,9 @@ func TestMain(m *testing.M) { } model.DB, _ = gorm.Open("mysql", db) defer db.Close() + + // 设置gin为测试模式 + gin.SetMode(gin.TestMode) m.Run() } @@ -113,9 +118,9 @@ func TestUserSession(t *testing.T) { ) router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - expectedJson, _ := json.Marshal(testCase.expected) - asserts.JSONEq(string(expectedJson), w.Body.String()) + asserts.Equal(200, w.Code) + expectedJSON, _ := json.Marshal(testCase.expected) + asserts.JSONEq(string(expectedJSON), w.Body.String()) w.Body.Reset() asserts.NoError(mock.ExpectationsWereMet()) @@ -123,3 +128,60 @@ func TestUserSession(t *testing.T) { } } + +func TestSessionAuthCheck(t *testing.T) { + asserts := assert.New(t) + router := InitRouter() + w := httptest.NewRecorder() + + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(sqlmock.NewRows([]string{"email", "nick", "password", "options"}). + AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}")) + expectedUser, _ := model.GetUserByID(1) + + testCases := []struct { + userRows *sqlmock.Rows + sessionMock map[string]interface{} + contextMock map[string]interface{} + expected interface{} + }{ + // 未登录 + { + expected: serializer.CheckLogin(), + }, + // 登录正常 + { + userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). + AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), + sessionMock: map[string]interface{}{"user_id": 1}, + expected: serializer.BuildUserResponse(expectedUser), + }, + // UID不存在 + { + userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}), + sessionMock: map[string]interface{}{"user_id": -1}, + expected: serializer.CheckLogin(), + }, + } + + for _, testCase := range testCases { + req, _ := http.NewRequest( + "GET", + "/Api/V3/User/Me", + nil, + ) + if testCase.userRows != nil { + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows) + } + middleware.ContextMock = testCase.contextMock + middleware.SessionMock = testCase.sessionMock + router.ServeHTTP(w, req) + expectedJSON, _ := json.Marshal(testCase.expected) + + asserts.Equal(200, w.Code) + asserts.JSONEq(string(expectedJSON), w.Body.String()) + asserts.NoError(mock.ExpectationsWereMet()) + + w.Body.Reset() + } + +}