Merge branch 'master' of github.com:cloudreve/Cloudreve

pull/779/head
ellermister 5 years ago
commit c3c19cd754

@ -90,7 +90,7 @@ func WebDAVAuth() gin.HandlerFunc {
return return
} }
expectedUser, err := model.GetUserByEmail(username) expectedUser, err := model.GetActiveUserByEmail(username)
if err != nil { if err != nil {
c.Status(http.StatusUnauthorized) c.Status(http.StatusUnauthorized)
c.Abort() c.Abort()

@ -139,6 +139,13 @@ func GetActiveUserByOpenID(openid string) (User, error) {
// GetUserByEmail 用Email获取用户 // GetUserByEmail 用Email获取用户
func GetUserByEmail(email string) (User, error) { func GetUserByEmail(email string) (User, error) {
var user User
result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user)
return user, result.Error
}
// GetActiveUserByEmail 用Email获取可登录用户
func GetActiveUserByEmail(email string) (User, error) {
var user User var user User
result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user) result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user)
return user, result.Error return user, result.Error

@ -352,10 +352,20 @@ func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
} }
} }
func TestGetUserByEmail(t *testing.T) { func TestGetActiveUserByEmail(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
_, err := GetActiveUserByEmail("abslant@foxmail.com")
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestGetUserByEmail(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
_, err := GetUserByEmail("abslant@foxmail.com") _, err := GetUserByEmail("abslant@foxmail.com")
asserts.Error(err) asserts.Error(err)

@ -2,13 +2,6 @@ package local
import ( import (
"context" "context"
"io"
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
@ -16,6 +9,12 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
) )
func TestHandler_Put(t *testing.T) { func TestHandler_Put(t *testing.T) {
@ -61,24 +60,34 @@ func TestHandler_Delete(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
handler := Driver{} handler := Driver{}
ctx := context.Background() ctx := context.Background()
filePath := util.RelativePath("test.file")
file, err := os.Create(util.RelativePath("test.file")) file, err := os.Create(filePath)
asserts.NoError(err) asserts.NoError(err)
_ = file.Close() _ = file.Close()
list, err := handler.Delete(ctx, []string{"test.file"}) list, err := handler.Delete(ctx, []string{"test.file"})
asserts.Equal([]string{}, list) asserts.Equal([]string{}, list)
asserts.NoError(err) asserts.NoError(err)
file, err = os.Create(util.RelativePath("test.file")) file, err = os.Create(filePath)
asserts.NoError(err)
_ = file.Close() _ = file.Close()
file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0))
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"test.file", "test.notexist"}) list, err = handler.Delete(ctx, []string{"test.file", "test.notexist"})
asserts.Equal([]string{"test.notexist"}, list) file.Close()
asserts.Error(err) asserts.Equal([]string{}, list)
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"test.notexist"}) list, err = handler.Delete(ctx, []string{"test.notexist"})
asserts.Equal([]string{"test.notexist"}, list) asserts.Equal([]string{}, list)
asserts.Error(err) asserts.NoError(err)
file, err = os.Create(filePath)
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"test.file"})
_ = file.Close()
asserts.Equal([]string{}, list)
asserts.NoError(err)
} }
func TestHandler_Get(t *testing.T) { func TestHandler_Get(t *testing.T) {

@ -18,7 +18,7 @@ import (
// StartLoginAuthn 开始注册WebAuthn登录 // StartLoginAuthn 开始注册WebAuthn登录
func StartLoginAuthn(c *gin.Context) { func StartLoginAuthn(c *gin.Context) {
userName := c.Param("username") userName := c.Param("username")
expectedUser, err := model.GetUserByEmail(userName) expectedUser, err := model.GetActiveUserByEmail(userName)
if err != nil { if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotFound, "用户不存在", err)) c.JSON(200, serializer.Err(serializer.CodeNotFound, "用户不存在", err))
return return
@ -52,7 +52,7 @@ func StartLoginAuthn(c *gin.Context) {
// FinishLoginAuthn 完成注册WebAuthn登录 // FinishLoginAuthn 完成注册WebAuthn登录
func FinishLoginAuthn(c *gin.Context) { func FinishLoginAuthn(c *gin.Context) {
userName := c.Param("username") userName := c.Param("username")
expectedUser, err := model.GetUserByEmail(userName) expectedUser, err := model.GetActiveUserByEmail(userName)
if err != nil { if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err)) c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err))
return return

@ -94,6 +94,12 @@ func (service *UserResetEmailService) Reset(c *gin.Context) serializer.Response
// 查找用户 // 查找用户
if user, err := model.GetUserByEmail(service.UserName); err == nil { if user, err := model.GetUserByEmail(service.UserName); err == nil {
if user.Status == model.Baned || user.Status == model.OveruseBaned {
return serializer.Err(403, "该账号已被封禁", nil)
}
if user.Status == model.NotActivicated {
return serializer.Err(403, "该账号未激活", nil)
}
// 创建密码重设会话 // 创建密码重设会话
secret := util.RandStringRunes(32) secret := util.RandStringRunes(32)
cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600) cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600)

@ -64,10 +64,17 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
user.Status = model.NotActivicated user.Status = model.NotActivicated
} }
user.GroupID = uint(defaultGroup) user.GroupID = uint(defaultGroup)
userNotActivated := false
// 创建用户 // 创建用户
if err := model.DB.Create(&user).Error; err != nil { if err := model.DB.Create(&user).Error; err != nil {
return serializer.DBErr("此邮箱已被使用", err) //检查已存在使用者是否尚未激活
expectedUser, err := model.GetUserByEmail(service.UserName)
if expectedUser.Status == model.NotActivicated {
userNotActivated = true
user = expectedUser
} else {
return serializer.DBErr("此邮箱已被使用", err)
}
} }
// 发送激活邮件 // 发送激活邮件
@ -100,8 +107,12 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
if err := email.Send(user.Email, title, body); err != nil { if err := email.Send(user.Email, title, body); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "无法发送激活邮件", err) return serializer.Err(serializer.CodeInternalSetting, "无法发送激活邮件", err)
} }
if userNotActivated == true {
return serializer.Response{Code: 203} //原本在上面要抛出的DBErr放来这边抛出
return serializer.DBErr("用户未激活,已重新发送激活邮件", nil)
} else {
return serializer.Response{Code: 203}
}
} }
return serializer.Response{} return serializer.Response{}

Loading…
Cancel
Save