Feat: model/policy

pull/247/head
HFO4 5 years ago
parent 41e0dec74c
commit 4309653160

@ -8,18 +8,16 @@ import (
// Group 用户组模型
type Group struct {
gorm.Model
Name string
Policies string
MaxStorage uint64
SpeedLimit int
ShareEnabled bool
RangeTransferEnabled bool
WebDAVEnabled bool
Aria2Option string
Color string
Name string
Policies string
MaxStorage uint64
ShareEnabled bool
WebDAVEnabled bool
Aria2Option string
Color string
// 数据库忽略字段
PolicyList []int `gorm:"-"`
PolicyList []uint `gorm:"-"`
}
// GetGroupByID 用ID获取用户组
@ -35,3 +33,16 @@ func (group *Group) AfterFind() (err error) {
err = json.Unmarshal([]byte(group.Policies), &group.PolicyList)
return err
}
// BeforeSave Save用户前的钩子
func (group *Group) BeforeSave() (err error) {
err = group.SerializePolicyList()
return err
}
//SerializePolicyList 将序列后的可选策略列表写入数据库字段
func (group *Group) SerializePolicyList() (err error) {
optionsValue, err := json.Marshal(&group.PolicyList)
group.Policies = string(optionsValue)
return err
}

@ -24,7 +24,7 @@ func TestGetGroupByID(t *testing.T) {
},
Name: "管理员",
Policies: "[1]",
PolicyList: []int{1},
PolicyList: []uint{1},
}, group)
//未找到用户时

@ -25,7 +25,10 @@ func migration() {
util.Log().Info("开始进行数据库自动迁移...")
// 自动迁移模式
DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}, &Group{})
DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{})
// 创建初始存储策略
addDefaultPolicy()
// 创建初始用户组
addDefaultGroups()
@ -44,6 +47,27 @@ func migration() {
}
func addDefaultPolicy() {
_, err := GetPolicyByID(1)
// 未找到初始存储策略时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultPolicy := Policy{
Name: "默认上传策略",
Type: "local",
Server: "/Api/V3/File/Upload",
BaseURL: "http://cloudreve.org/public/uploads/",
MaxSize: 10 * 1024 * 1024 * 1024,
AutoRename: true,
DirNameRule: "{date}/{uid}",
FileNameRule: "{uid}_{randomkey8}_{originname}",
IsOriginLinkEnable: false,
}
if err := DB.Create(&defaultPolicy).Error; err != nil {
util.Log().Panic("无法创建初始存储策略, ", err)
}
}
}
func addDefaultSettings() {
defaultSettings := []Setting{
{Name: "siteURL", Value: `http://lite.aoaoao.me/`, Type: "basic"},
@ -129,14 +153,13 @@ func addDefaultGroups() {
// 未找到初始管理组时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Group{
Name: "管理员",
Policies: "[1]",
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
Color: "danger",
RangeTransferEnabled: true,
WebDAVEnabled: true,
Aria2Option: "0,0,0",
Name: "管理员",
PolicyList: []uint{1},
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
Color: "danger",
WebDAVEnabled: true,
Aria2Option: "0,0,0",
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("无法创建管理用户组, ", err)
@ -148,14 +171,13 @@ func addDefaultGroups() {
// 未找到初始注册会员时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Group{
Name: "注册会员",
Policies: "[1]",
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
Color: "danger",
RangeTransferEnabled: true,
WebDAVEnabled: true,
Aria2Option: "0,0,0",
Name: "注册会员",
PolicyList: []uint{1},
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
Color: "danger",
WebDAVEnabled: true,
Aria2Option: "0,0,0",
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("无法创建初始注册会员用户组, ", err)

@ -0,0 +1,66 @@
package model
import (
"encoding/json"
"github.com/jinzhu/gorm"
)
// Policy 存储策略
type Policy struct {
// 表字段
gorm.Model
Name string
Type string
Server string
BucketName string
IsPrivate bool
BaseURL string
AccessKey string `gorm:"size:512"`
SecretKey string `gorm:"size:512"`
MaxSize uint64
AutoRename bool
DirNameRule string
FileNameRule string
IsOriginLinkEnable bool
Options string `gorm:"size:4096"`
// 数据库忽略字段
OptionsSerialized PolicyOption `gorm:"-"`
}
// PolicyOption 非公有的存储策略属性
type PolicyOption struct {
OPName string `json:"op_name"`
OPPassword string `json:"op_pwd"`
FileType []string `json:"file_type"`
MimeType string `json:"mimetype"`
SpeedLimit int `json:"speed_limit"`
RangeTransferEnabled bool `json:"range_transfer_enabled"`
}
// GetPolicyByID 用ID获取存储策略
func GetPolicyByID(ID interface{}) (Policy, error) {
var policy Policy
result := DB.First(&policy, ID)
return policy, result.Error
}
// AfterFind 找到上传策略后的钩子
func (policy *Policy) AfterFind() (err error) {
// 解析上传策略设置到OptionsSerialized
err = json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized)
return err
}
// BeforeSave Save策略前的钩子
func (policy *Policy) BeforeSave() (err error) {
err = policy.SerializeOptions()
return err
}
//SerializeOptions 将序列后的Option写入到数据库字段
func (policy *Policy) SerializeOptions() (err error) {
optionsValue, err := json.Marshal(&policy.OptionsSerialized)
policy.Options = string(optionsValue)
return err
}

@ -40,7 +40,8 @@ type User struct {
Options string `json:"-",gorm:"size:4096"`
// 关联模型
Group Group
Group Group
Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
// 数据库忽略字段
OptionsSerialized UserOption `gorm:"-"`
@ -48,8 +49,31 @@ type User struct {
// UserOption 用户个性化配置字段
type UserOption struct {
ProfileOn int `json:"profile_on"`
WebDAVKey string `json:"webdav_key"`
ProfileOn int `json:"profile_on"`
PreferredPolicy uint `json:"preferred_policy"`
WebDAVKey string `json:"webdav_key"`
}
// GetPolicyID 获取用户当前的上传策略ID
func (user *User) GetPolicyID() uint {
// 用户未指定时,返回可用的第一个
if user.OptionsSerialized.PreferredPolicy == 0 {
if len(user.Group.PolicyList) != 0 {
return user.Group.PolicyList[0]
}
return 1
} else {
// 用户指定时,先检查是否为可用策略列表中的值
if util.ContainsUint(user.Group.PolicyList, user.OptionsSerialized.PreferredPolicy) {
return user.OptionsSerialized.PreferredPolicy
}
// 不可用时,返回第一个
if len(user.Group.PolicyList) != 0 {
return user.Group.PolicyList[0]
}
return 1
}
}
// GetUserByID 用ID获取用户
@ -71,17 +95,32 @@ func NewUser() User {
options := UserOption{
ProfileOn: 1,
}
optionsValue, _ := json.Marshal(&options)
return User{
Avatar: "default",
Options: string(optionsValue),
Avatar: "default",
OptionsSerialized: options,
}
}
// BeforeSave Save用户前的钩子
func (user *User) BeforeSave() (err error) {
err = user.SerializeOptions()
return err
}
//SerializeOptions 将序列后的Option写入到数据库字段
func (user *User) SerializeOptions() (err error) {
optionsValue, err := json.Marshal(&user.OptionsSerialized)
user.Options = string(optionsValue)
return err
}
// AfterFind 找到用户后的钩子
func (user *User) AfterFind() (err error) {
// 解析用户设置到OptionsSerialized
err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized)
// 预加载存储策略
user.Policy, _ = GetPolicyByID(user.GetPolicyID())
return err
}

@ -37,7 +37,7 @@ func TestGetUserByID(t *testing.T) {
},
Name: "管理员",
Policies: "[1]",
PolicyList: []int{1},
PolicyList: []uint{1},
},
}, user)
@ -85,17 +85,72 @@ func TestNewUser(t *testing.T) {
newUser := NewUser()
asserts.IsType(User{}, newUser)
asserts.NotEmpty(newUser.Avatar)
asserts.NotEmpty(newUser.Options)
asserts.NotEmpty(newUser.OptionsSerialized)
}
func TestUser_AfterFind(t *testing.T) {
asserts := assert.New(t)
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "默认上传策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
newUser := NewUser()
err := newUser.AfterFind()
err = newUser.BeforeSave()
expected := UserOption{}
err = json.Unmarshal([]byte(newUser.Options), &expected)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(expected, newUser.OptionsSerialized)
asserts.Equal("默认上传策略", newUser.Policy.Name)
}
func TestUser_BeforeSave(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
err := newUser.BeforeSave()
expected, err := json.Marshal(newUser.OptionsSerialized)
asserts.NoError(err)
asserts.Equal(string(expected), newUser.Options)
}
func TestUser_GetPolicyID(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
testCases := []struct {
preferred uint
available []uint
expected uint
}{
{
available: []uint{1},
expected: 1,
},
{
available: []uint{5, 2, 3},
expected: 5,
},
{
preferred: 1,
available: []uint{5, 1, 3},
expected: 1,
},
{
preferred: 9,
available: []uint{5, 1, 3},
expected: 5,
},
}
for key, testCase := range testCases {
newUser.OptionsSerialized.PreferredPolicy = testCase.preferred
newUser.Group.PolicyList = testCase.available
asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key)
}
}

@ -14,3 +14,13 @@ func RandStringRunes(n int) string {
}
return string(b)
}
// ContainsUint 返回list中是否包含
func ContainsUint(s []uint, e uint) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}

Loading…
Cancel
Save