diff --git a/go.mod b/go.mod index 7ca1d6c..fab85f9 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.12 require ( github.com/DATA-DOG/go-sqlmock v1.3.3 github.com/fatih/color v1.7.0 + github.com/garyburd/redigo v1.6.0 github.com/gin-contrib/cors v1.3.0 github.com/gin-contrib/pprof v1.2.1 github.com/gin-contrib/sessions v0.0.1 diff --git a/models/migration_test.go b/models/migration_test.go index a28e11a..f2bf414 100644 --- a/models/migration_test.go +++ b/models/migration_test.go @@ -1,16 +1,17 @@ package model import ( - "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" ) func TestMigration(t *testing.T) { asserts := assert.New(t) - gin.SetMode(gin.TestMode) + DB, _ = gorm.Open("sqlite3", ":memory:") asserts.NotPanics(func() { migration() }) + DB = mockDB } diff --git a/models/policy.go b/models/policy.go index 00e389e..5f1c46f 100644 --- a/models/policy.go +++ b/models/policy.go @@ -2,11 +2,11 @@ package model import ( "encoding/json" + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" "path/filepath" "strconv" - "sync" "time" ) @@ -42,28 +42,20 @@ type PolicyOption struct { RangeTransferEnabled bool `json:"range_transfer_enabled"` } -// 存储策略缓存,部分情况下需要频繁查询存储策略 -var policyCache = make(map[uint]Policy) -var rw sync.RWMutex - // GetPolicyByID 用ID获取存储策略 func GetPolicyByID(ID interface{}) (Policy, error) { // 尝试读取缓存 - rw.RLock() - if policy, ok := policyCache[ID.(uint)]; ok { - rw.RUnlock() - return policy, nil + cacheKey := "policy_" + strconv.Itoa(int(ID.(uint))) + if policy, ok := cache.Store.Get(cacheKey); ok { + return policy.(Policy), nil } - rw.RUnlock() var policy Policy result := DB.First(&policy, ID) // 写入缓存 if result.Error == nil { - rw.Lock() - policyCache[policy.ID] = policy - rw.Unlock() + _ = cache.Store.Set(cacheKey, policy) } return policy, result.Error diff --git a/models/policy_test.go b/models/policy_test.go index 9ec34f7..7cc3bc0 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -12,18 +12,33 @@ import ( func TestGetPolicyByID(t *testing.T) { asserts := assert.New(t) - rows := sqlmock.NewRows([]string{"name", "type", "options"}). - AddRow("默认存储策略", "local", "{\"op_name\":\"123\"}") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND \\(\\(`policies`.`id` = 1\\)\\)(.+)$").WillReturnRows(rows) - policy, err := GetPolicyByID(uint(1)) - asserts.NoError(err) - asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OPName) + // 缓存未命中 + { + rows := sqlmock.NewRows([]string{"name", "type", "options"}). + AddRow("默认存储策略", "local", "{\"op_name\":\"123\"}") + mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) + policy, err := GetPolicyByID(uint(22)) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("默认存储策略", policy.Name) + asserts.Equal("123", policy.OptionsSerialized.OPName) + + rows = sqlmock.NewRows([]string{"name", "type", "options"}) + mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) + policy, err = GetPolicyByID(uint(23)) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + } + + // 命中 + { + policy, err := GetPolicyByID(uint(22)) + asserts.NoError(err) + asserts.Equal("默认存储策略", policy.Name) + asserts.Equal("123", policy.OptionsSerialized.OPName) + + } - rows = sqlmock.NewRows([]string{"name", "type", "options"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND \\(\\(`policies`.`id` = 1\\)\\)(.+)$").WillReturnRows(rows) - policy, err = GetPolicyByID(uint(1)) - asserts.Error(err) } func TestPolicy_BeforeSave(t *testing.T) { diff --git a/models/setting_test.go b/models/setting_test.go index ca76dee..c1f0854 100644 --- a/models/setting_test.go +++ b/models/setting_test.go @@ -9,6 +9,7 @@ import ( ) var mock sqlmock.Sqlmock +var mockDB *gorm.DB // TestMain 初始化数据库Mock func TestMain(m *testing.M) { @@ -19,6 +20,7 @@ func TestMain(m *testing.M) { panic("An error was not expected when opening a stub database connection") } DB, _ = gorm.Open("mysql", db) + mockDB = DB defer db.Close() m.Run() } diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go new file mode 100644 index 0000000..eb3bd8b --- /dev/null +++ b/pkg/cache/driver.go @@ -0,0 +1,10 @@ +package cache + +// Store 缓存存储器 +var Store Driver = NewMemoStore() + +// Driver 键值缓存存储容器 +type Driver interface { + Set(key string, value interface{}) error + Get(key string) (interface{}, bool) +} diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go new file mode 100644 index 0000000..fca7a20 --- /dev/null +++ b/pkg/cache/memo.go @@ -0,0 +1,26 @@ +package cache + +import "sync" + +// MemoStore 内存存储驱动 +type MemoStore struct { + Store *sync.Map +} + +// NewMemoStore 新建内存存储 +func NewMemoStore() *MemoStore { + return &MemoStore{ + Store: &sync.Map{}, + } +} + +// Set 存储值 +func (store *MemoStore) Set(key string, value interface{}) error { + store.Store.Store(key, value) + return nil +} + +// Get 取值 +func (store *MemoStore) Get(key string) (interface{}, bool) { + return store.Store.Load(key) +} diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go new file mode 100644 index 0000000..4a59b04 --- /dev/null +++ b/pkg/cache/memo_test.go @@ -0,0 +1,61 @@ +package cache + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewMemoStore(t *testing.T) { + asserts := assert.New(t) + + store := NewMemoStore() + asserts.NotNil(store) + asserts.NotNil(store.Store) +} + +func TestMemoStore_Set(t *testing.T) { + asserts := assert.New(t) + + store := NewMemoStore() + err := store.Set("KEY", "vAL") + asserts.NoError(err) + + val, ok := store.Store.Load("KEY") + asserts.True(ok) + asserts.Equal("vAL", val) +} + +func TestMemoStore_Get(t *testing.T) { + asserts := assert.New(t) + store := NewMemoStore() + + // 正常情况 + { + _ = store.Set("string", "string_val") + val, ok := store.Get("string") + asserts.Equal("string_val", val) + asserts.True(ok) + } + + // Key不存在 + { + val, ok := store.Get("something") + asserts.Equal(nil, val) + asserts.False(ok) + } + + // 存储struct + { + type testStruct struct { + key int + } + test := testStruct{key: 233} + _ = store.Set("struct", test) + val, ok := store.Get("struct") + asserts.True(ok) + res, ok := val.(testStruct) + asserts.True(ok) + asserts.Equal(test, res) + } + +}