diff --git a/conf/conf.ini b/conf/conf.ini index 0663538..b0fdfeb 100644 --- a/conf/conf.ini +++ b/conf/conf.ini @@ -1,5 +1,6 @@ [System] Debug = true +SessionSecret = 23333 [Database] Type = mysql diff --git a/go.mod b/go.mod index 3b04447..0862aaa 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/DATA-DOG/go-sqlmock v1.3.3 + github.com/gin-contrib/sessions v0.0.1 github.com/gin-gonic/gin v1.4.0 github.com/go-ini/ini v1.50.0 github.com/go-playground/locales v0.13.0 // indirect diff --git a/middleware/session.go b/middleware/session.go new file mode 100644 index 0000000..2dd8416 --- /dev/null +++ b/middleware/session.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" +) + +// Session 初始化session +func Session(secret string) gin.HandlerFunc { + store := cookie.NewStore([]byte(secret)) + //Also set Secure: true if using SSL, you should though + store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) + return sessions.Sessions("cloudreve-session", store) +} diff --git a/models/migration.go b/models/migration.go index 556add9..fad5022 100644 --- a/models/migration.go +++ b/models/migration.go @@ -122,7 +122,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti } func addDefaultUser() { - _, err := GetUser(1) + _, err := GetUserByID(1) // 未找到初始用户时,则创建 if gorm.IsRecordNotFoundError(err) { diff --git a/models/setting.go b/models/setting.go index f14ef10..a3bfe71 100644 --- a/models/setting.go +++ b/models/setting.go @@ -15,6 +15,11 @@ type Setting struct { // settingCache 设置项缓存 var settingCache = make(map[string]string) +// IsTrueVal 返回设置的值是否为真 +func IsTrueVal(val string) bool { + return val == "1" || val == "true" +} + // GetSettingByName 用 Name 获取设置值 func GetSettingByName(name string) string { var setting Setting diff --git a/models/user.go b/models/user.go index 4013134..1c055c8 100644 --- a/models/user.go +++ b/models/user.go @@ -24,29 +24,37 @@ const ( // User 用户模型 type User struct { gorm.Model - Email string `gorm:"type:varchar(100);unique_index"` - Nick string `gorm:"size:50"` - Password string - Status int - Group int - PrimaryGroup int - ActivationKey string - Storage int64 - LastNotify *time.Time - OpenID string - TwoFactor string - Delay int - Avatar string - Options string `gorm:"size:4096"` + Email string `gorm:"type:varchar(100);unique_index"` + Nick string `gorm:"size:50"` + Password string + Status int + Group int + PrimaryGroup int + ActivationKey string + Storage int64 + LastNotify *time.Time + OpenID string + TwoFactor string + Delay int + Avatar string + Options string `gorm:"size:4096"` + OptionsSerialized serializer.UserOption `gorm:"-"` } -// GetUser 用ID获取用户 -func GetUser(ID interface{}) (User, error) { +// GetUserByID 用ID获取用户 +func GetUserByID(ID interface{}) (User, error) { var user User result := DB.First(&user, ID) return user, result.Error } +// GetUserByEmail 用Email获取用户 +func GetUserByEmail(email string) (User, error) { + var user User + result := DB.Where("email = ?", email).First(&user) + return user, result.Error +} + // NewUser 返回一个新的空 User func NewUser() User { options := serializer.UserOption{ @@ -59,6 +67,13 @@ func NewUser() User { } } +// AfterFind 找到用户后的钩子 +func (user *User) AfterFind() (err error) { + // 解析用户设置到OptionsSerialized + err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) + return err +} + // CheckPassword 根据明文校验密码 func (user *User) CheckPassword(password string) (bool, error) { diff --git a/models/user_test.go b/models/user_test.go index 8e5715a..8f7172b 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -1,6 +1,8 @@ package model import ( + "cloudreve/pkg/serializer" + "encoding/json" "github.com/DATA-DOG/go-sqlmock" "github.com/jinzhu/gorm" "github.com/pkg/errors" @@ -8,7 +10,7 @@ import ( "testing" ) -func TestGetUser(t *testing.T) { +func TestGetUserByID(t *testing.T) { asserts := assert.New(t) //找到用户时 @@ -17,7 +19,7 @@ func TestGetUser(t *testing.T) { mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows) - user, err := GetUser(1) + user, err := GetUserByID(1) asserts.NoError(err) asserts.Equal(User{ Model: gorm.Model{ @@ -29,7 +31,7 @@ func TestGetUser(t *testing.T) { //未找到用户时 mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - user, err = GetUser(1) + user, err = GetUserByID(1) asserts.Error(err) asserts.Equal(User{}, user) } @@ -73,3 +75,15 @@ func TestNewUser(t *testing.T) { asserts.NotEmpty(newUser.Avatar) asserts.NotEmpty(newUser.Options) } + +func TestUser_AfterFind(t *testing.T) { + asserts := assert.New(t) + + newUser := NewUser() + err := newUser.AfterFind() + expected := serializer.UserOption{} + err = json.Unmarshal([]byte(newUser.Options), &expected) + + asserts.NoError(err) + asserts.Equal(expected, newUser.OptionsSerialized) +} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index c516db4..70b69c6 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -21,7 +21,8 @@ var DatabaseConfig = &database{ // system 系统通用配置 type system struct { - Debug bool + Debug bool + SessionSecret string } var SystemConfig = &system{} diff --git a/routers/controllers/main.go b/routers/controllers/main.go index 4be49f8..8edcd22 100644 --- a/routers/controllers/main.go +++ b/routers/controllers/main.go @@ -10,7 +10,7 @@ import ( func ParamErrorMsg(filed string, tag string) string { // 未通过验证的表单域与中文对应 fieldMap := map[string]string{ - "UserName": "用户名", + "UserName": "邮箱", "Password": "密码", } // 未通过的规则与中文对应 @@ -18,6 +18,7 @@ func ParamErrorMsg(filed string, tag string) string { "required": "不能为空", "min": "太短", "max": "太长", + "email": "格式不正确", } fieldVal, findField := fieldMap[filed] tagVal, findTag := tagMap[tag] diff --git a/routers/controllers/user.go b/routers/controllers/user.go index 162a740..7f8cf84 100644 --- a/routers/controllers/user.go +++ b/routers/controllers/user.go @@ -7,7 +7,7 @@ import ( // UserLogin 用户登录 func UserLogin(c *gin.Context) { - var service service.UserLoginService + var service user.UserLoginService if err := c.ShouldBindJSON(&service); err == nil { res := service.Login(c) c.JSON(200, res) diff --git a/routers/router.go b/routers/router.go index 8cc0f3e..f52fc90 100644 --- a/routers/router.go +++ b/routers/router.go @@ -1,6 +1,8 @@ package routers import ( + "cloudreve/middleware" + "cloudreve/pkg/conf" "cloudreve/routers/controllers" "github.com/gin-gonic/gin" ) @@ -8,13 +10,16 @@ import ( func InitRouter() *gin.Engine { r := gin.Default() + // 中间件 + r.Use(middleware.Session(conf.SystemConfig.SessionSecret)) + // 顶层路由分组 v3 := r.Group("/Api/V3") { // 测试用路由 v3.GET("Ping", controllers.Ping) // 用户登录 - v3.POST("User/Login", controllers.UserLogin) + v3.POST("User/Session", controllers.UserLogin) } return r diff --git a/service/user/common.go b/service/user/common.go new file mode 100644 index 0000000..e19d86a --- /dev/null +++ b/service/user/common.go @@ -0,0 +1,22 @@ +package user + +import ( + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// SetSession 设置session +func SetSession(c *gin.Context, list map[string]interface{}) { + s := sessions.Default(c) + for key, value := range list { + s.Set(key, value) + } + s.Save() +} + +// ClearSession 清空session +func ClearSession(c *gin.Context) { + s := sessions.Default(c) + s.Clear() + s.Save() +} diff --git a/service/user/login.go b/service/user/login.go new file mode 100644 index 0000000..03e3e7d --- /dev/null +++ b/service/user/login.go @@ -0,0 +1,55 @@ +package user + +import ( + "cloudreve/models" + "cloudreve/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// UserLoginService 管理用户登录的服务 +type UserLoginService struct { + //TODO 细致调整验证规则 + UserName string `form:"userName" json:"userName" binding:"required,email"` + Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"` + CaptchaCode string `form:"captchaCode" json:"captchaCode"` +} + +// Login 用户登录函数 +func (service *UserLoginService) Login(c *gin.Context) serializer.Response { + isCaptchaRequired := model.GetSettingByName("login_captcha") + expectedUser, err := model.GetUserByEmail(service.UserName) + + if model.IsTrueVal(isCaptchaRequired) { + // TODO 验证码校验 + } + + // 一系列校验 + if err != nil { + return serializer.Err(401, "用户邮箱或密码错误", err) + } + if authOK, _ := expectedUser.CheckPassword(service.Password); !authOK { + return serializer.Err(401, "用户邮箱或密码错误", nil) + } + if expectedUser.Status == model.Baned { + return serializer.Err(403, "该账号已被封禁", nil) + } + if expectedUser.Status == model.NotActivicated { + return serializer.Err(403, "该账号未激活", nil) + } + + if expectedUser.TwoFactor != "" { + //TODO 二步验证处理 + } + + //登陆成功,清空并设置session + ClearSession(c) + SetSession(c, map[string]interface{}{ + "user_id": expectedUser.ID, + }) + + return serializer.Response{ + Code: 0, + Msg: "", + } + +} diff --git a/service/user/user_login.go b/service/user/user_login.go deleted file mode 100644 index 33ea6fa..0000000 --- a/service/user/user_login.go +++ /dev/null @@ -1,28 +0,0 @@ -package service - -import ( - "cloudreve/models" - "cloudreve/pkg/serializer" - "fmt" - "github.com/gin-gonic/gin" -) - -// UserLoginService 管理用户登录的服务 -type UserLoginService struct { - //TODO 细致调整验证规则 - UserName string `form:"userName" json:"userName" binding:"required,min=5,max=30"` - Password string `form:"Password" json:"Password" binding:"required,min=8,max=40"` - CaptchaCode string `form:"captchaCode" json:"captchaCode"` -} - -// Login 用户登录函数 -func (service *UserLoginService) Login(c *gin.Context) serializer.Response { - siteName := model.GetSettingByName("siteName") - basic := model.GetSettingByNames([]string{"siteDes", "siteKeywords"}) - fmt.Println(basic) - return serializer.Response{ - Code: 0, - Msg: siteName, - } - -}