diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..0780d02 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "cloudreve/models" + "cloudreve/pkg/serializer" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// CurrentUser 获取登录用户 +func CurrentUser() gin.HandlerFunc { + return func(c *gin.Context) { + session := sessions.Default(c) + uid := session.Get("user_id") + if uid != nil { + user, err := model.GetUserByID(uid) + if err == nil { + c.Set("user", &user) + } + } + c.Next() + } +} + +// AuthRequired 需要登录 +func AuthRequired() gin.HandlerFunc { + return func(c *gin.Context) { + if user, _ := c.Get("user"); user != nil { + if _, ok := user.(*model.User); ok { + c.Next() + return + } + } + + c.JSON(200, serializer.CheckLogin()) + c.Abort() + } +} diff --git a/middleware/session.go b/middleware/session.go index 8c26aa9..14e3b3f 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -6,10 +6,13 @@ import ( "github.com/gin-gonic/gin" ) +// Store session存储 +var Store memstore.Store + // Session 初始化session func Session(secret string) gin.HandlerFunc { - store := memstore.NewStore([]byte(secret)) - //Also set Secure: true if using SSL, you should though - store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) - return sessions.Sessions("cloudreve-session", store) + Store = memstore.NewStore([]byte(secret)) + // Also set Secure: true if using SSL, you should though + Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) + return sessions.Sessions("cloudreve-session", Store) } diff --git a/middleware/session_test.go b/middleware/session_test.go new file mode 100644 index 0000000..d755b1d --- /dev/null +++ b/middleware/session_test.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSession(t *testing.T) { + asserts := assert.New(t) + + handler := Session("2333") + asserts.NotNil(handler) + asserts.NotNil(Store) + asserts.IsType(emptyFunc(), handler) +} + +func emptyFunc() gin.HandlerFunc { + return func(c *gin.Context) {} +} diff --git a/models/user.go b/models/user.go index 4f9329f..5d8a1bf 100644 --- a/models/user.go +++ b/models/user.go @@ -1,7 +1,6 @@ package model import ( - "cloudreve/pkg/serializer" "cloudreve/pkg/util" "crypto/sha1" "encoding/hex" @@ -37,8 +36,14 @@ type User struct { TwoFactor string `json:"-"` Delay int Avatar string - Options string `json:"-",gorm:"size:4096"` - OptionsSerialized serializer.UserOption `gorm:"-"` + Options string `json:"-",gorm:"size:4096"` + OptionsSerialized UserOption `gorm:"-"` +} + +// UserOption 用户个性化配置字段 +type UserOption struct { + ProfileOn int `json:"profile_on"` + WebDAVKey string `json:"webdav_key"` } // GetUserByID 用ID获取用户 @@ -57,7 +62,7 @@ func GetUserByEmail(email string) (User, error) { // NewUser 返回一个新的空 User func NewUser() User { - options := serializer.UserOption{ + options := UserOption{ ProfileOn: 1, } optionsValue, _ := json.Marshal(&options) diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go index df0c7ac..d99301c 100644 --- a/pkg/serializer/user.go +++ b/pkg/serializer/user.go @@ -1,7 +1,40 @@ package serializer -// UserOption 用户个性化配置字段 -type UserOption struct { - ProfileOn int `json:"profile_on"` - WebDAVKey string `json:"webdav_key"` +import "cloudreve/models" + +// CheckLogin 检查登录 +func CheckLogin() Response { + return Response{ + Code: CodeCheckLogin, + Msg: "未登录", + } +} + +// User 用户序列化器 +type User struct { + ID uint `json:"id"` + Email string `json:"user_name"` + Nickname string `json:"nickname"` + Status int `json:"status"` + Avatar string `json:"avatar"` + CreatedAt int64 `json:"created_at"` +} + +// BuildUser 序列化用户 +func BuildUser(user model.User) User { + return User{ + ID: user.ID, + Email: user.Email, + Nickname: user.Nick, + Status: user.Status, + Avatar: user.Avatar, + CreatedAt: user.CreatedAt.Unix(), + } +} + +// BuildUserResponse 序列化用户响应 +func BuildUserResponse(user model.User) Response { + return Response{ + Data: BuildUser(user), + } } diff --git a/routers/controllers/main.go b/routers/controllers/main.go index 8edcd22..99e6614 100644 --- a/routers/controllers/main.go +++ b/routers/controllers/main.go @@ -1,8 +1,10 @@ package controllers import ( + "cloudreve/models" "cloudreve/pkg/serializer" "encoding/json" + "github.com/gin-gonic/gin" "gopkg.in/go-playground/validator.v8" ) @@ -47,3 +49,13 @@ func ErrorResponse(err error) serializer.Response { return serializer.ParamErr("参数错误", err) } + +// CurrentUser 获取当前用户 +func CurrentUser(c *gin.Context) *model.User { + if user, _ := c.Get("user"); user != nil { + if u, ok := user.(*model.User); ok { + return u + } + } + return nil +} diff --git a/routers/controllers/user.go b/routers/controllers/user.go index 7f8cf84..ae77bf3 100644 --- a/routers/controllers/user.go +++ b/routers/controllers/user.go @@ -1,6 +1,7 @@ package controllers import ( + "cloudreve/pkg/serializer" "cloudreve/service/user" "github.com/gin-gonic/gin" ) @@ -16,3 +17,11 @@ func UserLogin(c *gin.Context) { } } + +// UserMe 获取当前登录的用户 +func UserMe(c *gin.Context) { + user := CurrentUser(c) + res := serializer.BuildUserResponse(*user) + c.JSON(200, res) + +} diff --git a/routers/router.go b/routers/router.go index f52fc90..89a79bf 100644 --- a/routers/router.go +++ b/routers/router.go @@ -7,11 +7,13 @@ import ( "github.com/gin-gonic/gin" ) +// InitRouter 初始化路由 func InitRouter() *gin.Engine { r := gin.Default() // 中间件 r.Use(middleware.Session(conf.SystemConfig.SessionSecret)) + r.Use(middleware.CurrentUser()) // 顶层路由分组 v3 := r.Group("/Api/V3") @@ -21,6 +23,19 @@ func InitRouter() *gin.Engine { // 用户登录 v3.POST("User/Session", controllers.UserLogin) + // 需要登录保护的 + auth := v3.Group("") + auth.Use(middleware.AuthRequired()) + { + // 用户类 + user := auth.Group("User") + { + // 当前登录用户信息 + user.GET("Me", controllers.UserMe) + } + + } + } return r } diff --git a/service/user/login.go b/service/user/login.go index 7c6b791..d474c09 100644 --- a/service/user/login.go +++ b/service/user/login.go @@ -46,10 +46,6 @@ func (service *UserLoginService) Login(c *gin.Context) serializer.Response { "user_id": expectedUser.ID, }) - return serializer.Response{ - Code: 0, - Data: &expectedUser, - Msg: "", - } + return serializer.BuildUserResponse(expectedUser) }