From 430965316050de0b64fd570abf7d0f6f46f6108e Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 14 Nov 2019 14:18:10 +0800 Subject: [PATCH] Feat: model/policy --- models/group.go | 31 ++++++++++++++------- models/group_test.go | 2 +- models/migration.go | 56 +++++++++++++++++++++++++------------ models/policy.go | 66 ++++++++++++++++++++++++++++++++++++++++++++ models/user.go | 51 ++++++++++++++++++++++++++++++---- models/user_test.go | 59 +++++++++++++++++++++++++++++++++++++-- pkg/util/common.go | 10 +++++++ 7 files changed, 239 insertions(+), 36 deletions(-) create mode 100644 models/policy.go diff --git a/models/group.go b/models/group.go index 2338c16..1d8acba 100644 --- a/models/group.go +++ b/models/group.go @@ -8,18 +8,16 @@ import ( // Group 用户组模型 type Group struct { gorm.Model - Name string - Policies string - MaxStorage uint64 - SpeedLimit int - ShareEnabled bool - RangeTransferEnabled bool - WebDAVEnabled bool - Aria2Option string - Color string + Name string + Policies string + MaxStorage uint64 + ShareEnabled bool + WebDAVEnabled bool + Aria2Option string + Color string // 数据库忽略字段 - PolicyList []int `gorm:"-"` + PolicyList []uint `gorm:"-"` } // GetGroupByID 用ID获取用户组 @@ -35,3 +33,16 @@ func (group *Group) AfterFind() (err error) { err = json.Unmarshal([]byte(group.Policies), &group.PolicyList) return err } + +// BeforeSave Save用户前的钩子 +func (group *Group) BeforeSave() (err error) { + err = group.SerializePolicyList() + return err +} + +//SerializePolicyList 将序列后的可选策略列表写入数据库字段 +func (group *Group) SerializePolicyList() (err error) { + optionsValue, err := json.Marshal(&group.PolicyList) + group.Policies = string(optionsValue) + return err +} diff --git a/models/group_test.go b/models/group_test.go index 9e28c19..1adb9e6 100644 --- a/models/group_test.go +++ b/models/group_test.go @@ -24,7 +24,7 @@ func TestGetGroupByID(t *testing.T) { }, Name: "管理员", Policies: "[1]", - PolicyList: []int{1}, + PolicyList: []uint{1}, }, group) //未找到用户时 diff --git a/models/migration.go b/models/migration.go index 401e231..5fdfb49 100644 --- a/models/migration.go +++ b/models/migration.go @@ -25,7 +25,10 @@ func migration() { util.Log().Info("开始进行数据库自动迁移...") // 自动迁移模式 - DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}, &Group{}) + DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}) + + // 创建初始存储策略 + addDefaultPolicy() // 创建初始用户组 addDefaultGroups() @@ -44,6 +47,27 @@ func migration() { } +func addDefaultPolicy() { + _, err := GetPolicyByID(1) + // 未找到初始存储策略时,则创建 + if gorm.IsRecordNotFoundError(err) { + defaultPolicy := Policy{ + Name: "默认上传策略", + Type: "local", + Server: "/Api/V3/File/Upload", + BaseURL: "http://cloudreve.org/public/uploads/", + MaxSize: 10 * 1024 * 1024 * 1024, + AutoRename: true, + DirNameRule: "{date}/{uid}", + FileNameRule: "{uid}_{randomkey8}_{originname}", + IsOriginLinkEnable: false, + } + if err := DB.Create(&defaultPolicy).Error; err != nil { + util.Log().Panic("无法创建初始存储策略, ", err) + } + } +} + func addDefaultSettings() { defaultSettings := []Setting{ {Name: "siteURL", Value: `http://lite.aoaoao.me/`, Type: "basic"}, @@ -129,14 +153,13 @@ func addDefaultGroups() { // 未找到初始管理组时,则创建 if gorm.IsRecordNotFoundError(err) { defaultAdminGroup := Group{ - Name: "管理员", - Policies: "[1]", - MaxStorage: 1 * 1024 * 1024 * 1024, - ShareEnabled: true, - Color: "danger", - RangeTransferEnabled: true, - WebDAVEnabled: true, - Aria2Option: "0,0,0", + Name: "管理员", + PolicyList: []uint{1}, + MaxStorage: 1 * 1024 * 1024 * 1024, + ShareEnabled: true, + Color: "danger", + WebDAVEnabled: true, + Aria2Option: "0,0,0", } if err := DB.Create(&defaultAdminGroup).Error; err != nil { util.Log().Panic("无法创建管理用户组, ", err) @@ -148,14 +171,13 @@ func addDefaultGroups() { // 未找到初始注册会员时,则创建 if gorm.IsRecordNotFoundError(err) { defaultAdminGroup := Group{ - Name: "注册会员", - Policies: "[1]", - MaxStorage: 1 * 1024 * 1024 * 1024, - ShareEnabled: true, - Color: "danger", - RangeTransferEnabled: true, - WebDAVEnabled: true, - Aria2Option: "0,0,0", + Name: "注册会员", + PolicyList: []uint{1}, + MaxStorage: 1 * 1024 * 1024 * 1024, + ShareEnabled: true, + Color: "danger", + WebDAVEnabled: true, + Aria2Option: "0,0,0", } if err := DB.Create(&defaultAdminGroup).Error; err != nil { util.Log().Panic("无法创建初始注册会员用户组, ", err) diff --git a/models/policy.go b/models/policy.go new file mode 100644 index 0000000..e4bdb19 --- /dev/null +++ b/models/policy.go @@ -0,0 +1,66 @@ +package model + +import ( + "encoding/json" + "github.com/jinzhu/gorm" +) + +// Policy 存储策略 +type Policy struct { + // 表字段 + gorm.Model + Name string + Type string + Server string + BucketName string + IsPrivate bool + BaseURL string + AccessKey string `gorm:"size:512"` + SecretKey string `gorm:"size:512"` + MaxSize uint64 + AutoRename bool + DirNameRule string + FileNameRule string + IsOriginLinkEnable bool + Options string `gorm:"size:4096"` + + // 数据库忽略字段 + OptionsSerialized PolicyOption `gorm:"-"` +} + +// PolicyOption 非公有的存储策略属性 +type PolicyOption struct { + OPName string `json:"op_name"` + OPPassword string `json:"op_pwd"` + FileType []string `json:"file_type"` + MimeType string `json:"mimetype"` + SpeedLimit int `json:"speed_limit"` + RangeTransferEnabled bool `json:"range_transfer_enabled"` +} + +// GetPolicyByID 用ID获取存储策略 +func GetPolicyByID(ID interface{}) (Policy, error) { + var policy Policy + result := DB.First(&policy, ID) + return policy, result.Error +} + +// AfterFind 找到上传策略后的钩子 +func (policy *Policy) AfterFind() (err error) { + // 解析上传策略设置到OptionsSerialized + err = json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized) + return err +} + +// BeforeSave Save策略前的钩子 +func (policy *Policy) BeforeSave() (err error) { + err = policy.SerializeOptions() + return err +} + +//SerializeOptions 将序列后的Option写入到数据库字段 +func (policy *Policy) SerializeOptions() (err error) { + optionsValue, err := json.Marshal(&policy.OptionsSerialized) + policy.Options = string(optionsValue) + return err +} diff --git a/models/user.go b/models/user.go index 22e47cd..054bf30 100644 --- a/models/user.go +++ b/models/user.go @@ -40,7 +40,8 @@ type User struct { Options string `json:"-",gorm:"size:4096"` // 关联模型 - Group Group + Group Group + Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` // 数据库忽略字段 OptionsSerialized UserOption `gorm:"-"` @@ -48,8 +49,31 @@ type User struct { // UserOption 用户个性化配置字段 type UserOption struct { - ProfileOn int `json:"profile_on"` - WebDAVKey string `json:"webdav_key"` + ProfileOn int `json:"profile_on"` + PreferredPolicy uint `json:"preferred_policy"` + WebDAVKey string `json:"webdav_key"` +} + +// GetPolicyID 获取用户当前的上传策略ID +func (user *User) GetPolicyID() uint { + // 用户未指定时,返回可用的第一个 + if user.OptionsSerialized.PreferredPolicy == 0 { + if len(user.Group.PolicyList) != 0 { + return user.Group.PolicyList[0] + } + return 1 + } else { + // 用户指定时,先检查是否为可用策略列表中的值 + if util.ContainsUint(user.Group.PolicyList, user.OptionsSerialized.PreferredPolicy) { + return user.OptionsSerialized.PreferredPolicy + } + // 不可用时,返回第一个 + if len(user.Group.PolicyList) != 0 { + return user.Group.PolicyList[0] + } + return 1 + + } } // GetUserByID 用ID获取用户 @@ -71,17 +95,32 @@ func NewUser() User { options := UserOption{ ProfileOn: 1, } - optionsValue, _ := json.Marshal(&options) return User{ - Avatar: "default", - Options: string(optionsValue), + Avatar: "default", + OptionsSerialized: options, } } +// BeforeSave Save用户前的钩子 +func (user *User) BeforeSave() (err error) { + err = user.SerializeOptions() + return err +} + +//SerializeOptions 将序列后的Option写入到数据库字段 +func (user *User) SerializeOptions() (err error) { + optionsValue, err := json.Marshal(&user.OptionsSerialized) + user.Options = string(optionsValue) + return err +} + // AfterFind 找到用户后的钩子 func (user *User) AfterFind() (err error) { // 解析用户设置到OptionsSerialized err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) + + // 预加载存储策略 + user.Policy, _ = GetPolicyByID(user.GetPolicyID()) return err } diff --git a/models/user_test.go b/models/user_test.go index afc7b63..a309b52 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -37,7 +37,7 @@ func TestGetUserByID(t *testing.T) { }, Name: "管理员", Policies: "[1]", - PolicyList: []int{1}, + PolicyList: []uint{1}, }, }, user) @@ -85,17 +85,72 @@ func TestNewUser(t *testing.T) { newUser := NewUser() asserts.IsType(User{}, newUser) asserts.NotEmpty(newUser.Avatar) - asserts.NotEmpty(newUser.Options) + asserts.NotEmpty(newUser.OptionsSerialized) } func TestUser_AfterFind(t *testing.T) { asserts := assert.New(t) + policyRows := sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, "默认上传策略") + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) + newUser := NewUser() err := newUser.AfterFind() + err = newUser.BeforeSave() expected := UserOption{} err = json.Unmarshal([]byte(newUser.Options), &expected) asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) asserts.Equal(expected, newUser.OptionsSerialized) + asserts.Equal("默认上传策略", newUser.Policy.Name) +} + +func TestUser_BeforeSave(t *testing.T) { + asserts := assert.New(t) + + newUser := NewUser() + err := newUser.BeforeSave() + expected, err := json.Marshal(newUser.OptionsSerialized) + + asserts.NoError(err) + asserts.Equal(string(expected), newUser.Options) +} + +func TestUser_GetPolicyID(t *testing.T) { + asserts := assert.New(t) + + newUser := NewUser() + + testCases := []struct { + preferred uint + available []uint + expected uint + }{ + { + available: []uint{1}, + expected: 1, + }, + { + available: []uint{5, 2, 3}, + expected: 5, + }, + { + preferred: 1, + available: []uint{5, 1, 3}, + expected: 1, + }, + { + preferred: 9, + available: []uint{5, 1, 3}, + expected: 5, + }, + } + + for key, testCase := range testCases { + newUser.OptionsSerialized.PreferredPolicy = testCase.preferred + newUser.Group.PolicyList = testCase.available + asserts.Equal(testCase.expected, newUser.GetPolicyID(), "测试用例 #%d 未通过", key) + } } diff --git a/pkg/util/common.go b/pkg/util/common.go index ff4c772..4f058fc 100644 --- a/pkg/util/common.go +++ b/pkg/util/common.go @@ -14,3 +14,13 @@ func RandStringRunes(n int) string { } return string(b) } + +// ContainsUint 返回list中是否包含 +func ContainsUint(s []uint, e uint) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +}