diff --git a/go.mod b/go.mod index 97caab2..f924906 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gin-contrib/sessions v0.0.1 github.com/gin-gonic/gin v1.4.0 github.com/go-ini/ini v1.50.0 + github.com/gomodule/redigo v2.0.0+incompatible github.com/jinzhu/gorm v1.9.11 github.com/juju/ratelimit v1.0.1 github.com/mattn/go-colorable v0.1.4 // indirect @@ -20,6 +21,7 @@ require ( github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/pkg/errors v0.8.0 github.com/qiniu/api.v7/v7 v7.4.0 + github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 github.com/smartystreets/goconvey v1.6.4 // indirect github.com/stretchr/testify v1.4.0 gopkg.in/go-playground/validator.v8 v8.18.2 diff --git a/models/setting.go b/models/setting.go index 268ceee..2a9498b 100644 --- a/models/setting.go +++ b/models/setting.go @@ -37,17 +37,16 @@ func GetSettingByName(name string) string { } // GetSettingByNames 用多个 Name 获取设置值 -// TODO 其他设置获取也使用缓存 func GetSettingByNames(names []string) map[string]string { var queryRes []Setting - res, miss := cache.GetsSettingByName(names) + res, miss := cache.GetSettings(names, "setting_") DB.Where("name IN (?)", miss).Find(&queryRes) for _, setting := range queryRes { res[setting.Name] = setting.Value } - _ = cache.SetSettings(res) + _ = cache.SetSettings(res, "setting_") return res } diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 0f0ff64..04cbeba 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -9,13 +9,18 @@ import ( var Store Driver func init() { - Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") - return - + //Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") + //return if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode { Store = NewMemoStore() } else { - Store = NewRedisStore(10, "tcp", conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB) + Store = NewRedisStore( + 10, + "tcp", + conf.RedisConfig.Server, + conf.RedisConfig.Password, + conf.RedisConfig.DB, + ) } } @@ -41,9 +46,9 @@ func Get(key string) (interface{}, bool) { return Store.Get(key) } -// GetsSettingByName 根据名称批量获取设置项缓存 -func GetsSettingByName(keys []string) (map[string]string, []string) { - raw, miss := Store.Gets(keys, "setting_") +// GetSettings 根据名称批量获取设置项缓存 +func GetSettings(keys []string, prefix string) (map[string]string, []string) { + raw, miss := Store.Gets(keys, prefix) res := make(map[string]string, len(raw)) for k, v := range raw { @@ -54,10 +59,10 @@ func GetsSettingByName(keys []string) (map[string]string, []string) { } // SetSettings 批量设置站点设置缓存 -func SetSettings(values map[string]string) error { +func SetSettings(values map[string]string, prefix string) error { var toBeSet = make(map[string]interface{}, len(values)) for key, value := range values { toBeSet[key] = interface{}(value) } - return Store.Sets(toBeSet, "setting_") + return Store.Sets(toBeSet, prefix) } diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go new file mode 100644 index 0000000..677c5ea --- /dev/null +++ b/pkg/cache/driver_test.go @@ -0,0 +1,44 @@ +package cache + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSet(t *testing.T) { + asserts := assert.New(t) + + asserts.NoError(Set("123", "321")) +} + +func TestGet(t *testing.T) { + asserts := assert.New(t) + asserts.NoError(Set("123", "321")) + + value, ok := Get("123") + asserts.True(ok) + asserts.Equal("321", value) + + value, ok = Get("not_exist") + asserts.False(ok) +} + +func TestGetSettings(t *testing.T) { + asserts := assert.New(t) + asserts.NoError(Set("test_1", "1")) + + values, missed := GetSettings([]string{"1", "2"}, "test_") + asserts.Equal(map[string]string{"1": "1"}, values) + asserts.Equal([]string{"2"}, missed) +} + +func TestSetSettings(t *testing.T) { + asserts := assert.New(t) + + err := SetSettings(map[string]string{"3": "3", "4": "4"}, "test_") + asserts.NoError(err) + value1, _ := Get("test_3") + value2, _ := Get("test_4") + asserts.Equal("3", value1) + asserts.Equal("4", value2) +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index be30744..f7fa63f 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/gob" "github.com/HFO4/cloudreve/pkg/util" - "github.com/garyburd/redigo/redis" + "github.com/gomodule/redigo/redis" "strconv" "time" ) @@ -18,6 +18,30 @@ type item struct { Value interface{} } +func serializer(value interface{}) ([]byte, error) { + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + storeValue := item{ + Value: value, + } + err := enc.Encode(storeValue) + if err != nil { + return nil, err + } + return buffer.Bytes(), nil +} + +func deserializer(value []byte) (interface{}, error) { + var res item + buffer := bytes.NewReader(value) + dec := gob.NewDecoder(buffer) + err := dec.Decode(&res) + if err != nil { + return nil, err + } + return res.Value, nil +} + // NewRedisStore 创建新的redis存储 func NewRedisStore(size int, network, address, password, database string) *RedisStore { return &RedisStore{ @@ -55,46 +79,42 @@ func (store *RedisStore) Set(key string, value interface{}) error { rc := store.pool.Get() defer rc.Close() - var buffer bytes.Buffer - enc := gob.NewEncoder(&buffer) - storeValue := item{ - Value: value, - } - err := enc.Encode(storeValue) + serialized, err := serializer(value) if err != nil { return err } - if rc.Err() == nil { - _, err := rc.Do("SET", key, buffer.Bytes()) - if err != nil { - return err - } - return nil + if rc.Err() != nil { + return rc.Err() } - return rc.Err() + _, err = rc.Do("SET", key, serialized) + if err != nil { + return err + } + return nil + } // Get 取值 func (store *RedisStore) Get(key string) (interface{}, bool) { rc := store.pool.Get() defer rc.Close() + if rc.Err() != nil { + return nil, false + } v, err := redis.Bytes(rc.Do("GET", key)) - if err != nil { + if err != nil || v == nil { return nil, false } - var res item - buffer := bytes.NewReader(v) - dec := gob.NewDecoder(buffer) - err = dec.Decode(&res) + finalValue, err := deserializer(v) if err != nil { return nil, false } - return res.Value, true + return finalValue, true } @@ -102,6 +122,9 @@ func (store *RedisStore) Get(key string) (interface{}, bool) { func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { rc := store.pool.Get() defer rc.Close() + if rc.Err() != nil { + return nil, keys + } var queryKeys = make([]string, len(keys)) for key, value := range keys { @@ -117,14 +140,11 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac var missed = make([]string, 0, len(keys)) for key, value := range v { - var decoded item - buffer := bytes.NewReader(value) - dec := gob.NewDecoder(buffer) - err = dec.Decode(&decoded) - if err != nil || decoded.Value == nil { + decoded, err := deserializer(value) + if err != nil || decoded == nil { missed = append(missed, keys[key]) } else { - res[keys[key]] = decoded.Value + res[keys[key]] = decoded } } // 解码所得值 @@ -135,20 +155,18 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error { rc := store.pool.Get() defer rc.Close() + if rc.Err() != nil { + return rc.Err() + } var setValues = make(map[string]interface{}) // 编码待设置值 for key, value := range values { - var buffer bytes.Buffer - enc := gob.NewEncoder(&buffer) - storeValue := item{ - Value: value, - } - err := enc.Encode(storeValue) + serialized, err := serializer(value) if err != nil { return err } - setValues[prefix+key] = buffer.Bytes() + setValues[prefix+key] = serialized } if rc.Err() == nil { diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go new file mode 100644 index 0000000..41c9f14 --- /dev/null +++ b/pkg/cache/redis_test.go @@ -0,0 +1,254 @@ +package cache + +import ( + "errors" + "fmt" + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewRedisStore(t *testing.T) { + asserts := assert.New(t) + + store := NewRedisStore(10, "tcp", ":2333", "", "0") + asserts.NotNil(store) +} + +func TestRedisStore_Set(t *testing.T) { + asserts := assert.New(t) + conn := redigomock.NewConn() + pool := &redis.Pool{ + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + store := &RedisStore{pool: pool} + + // 正常情况 + { + cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectStringSlice("OK") + err := store.Set("test", "test val") + asserts.NoError(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + + // 序列化出错 + { + value := struct { + Key string + }{ + Key: "123", + } + err := store.Set("test", value) + asserts.Error(err) + } + + // 命令执行失败 + { + conn.Clear() + cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectError(errors.New("error")) + err := store.Set("test", "test val") + asserts.Error(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + // 获取连接失败 + { + store.pool = &redis.Pool{ + Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, + MaxIdle: 10, + } + err := store.Set("test", "123") + asserts.Error(err) + } + +} + +func TestRedisStore_Get(t *testing.T) { + asserts := assert.New(t) + conn := redigomock.NewConn() + pool := &redis.Pool{ + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + store := &RedisStore{pool: pool} + + // 正常情况 + { + expectVal, _ := serializer("test val") + cmd := conn.Command("GET", "test").Expect(expectVal) + val, ok := store.Get("test") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.True(ok) + asserts.Equal("test val", val.(string)) + } + + // Key不存在 + { + conn.Clear() + cmd := conn.Command("GET", "test").Expect(nil) + val, ok := store.Get("test") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.False(ok) + asserts.Nil(val) + } + // 解码错误 + { + conn.Clear() + cmd := conn.Command("GET", "test").Expect([]byte{0x20}) + val, ok := store.Get("test") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.False(ok) + asserts.Nil(val) + } + // 获取连接失败 + { + store.pool = &redis.Pool{ + Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, + MaxIdle: 10, + } + val, ok := store.Get("test") + asserts.False(ok) + asserts.Nil(val) + } +} + +func TestRedisStore_Gets(t *testing.T) { + asserts := assert.New(t) + conn := redigomock.NewConn() + pool := &redis.Pool{ + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + store := &RedisStore{pool: pool} + + // 全部命中 + { + conn.Clear() + value1, _ := serializer("1") + value2, _ := serializer("2") + cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice( + value1, value2) + res, missed := store.Gets([]string{"1", "2"}, "test_") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.Len(missed, 0) + asserts.Len(res, 2) + asserts.Equal("1", res["1"].(string)) + asserts.Equal("2", res["2"].(string)) + } + + // 命中一个 + { + conn.Clear() + value2, _ := serializer("2") + cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice( + nil, value2) + res, missed := store.Gets([]string{"1", "2"}, "test_") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.Len(missed, 1) + asserts.Len(res, 1) + asserts.Equal("1", missed[0]) + asserts.Equal("2", res["2"].(string)) + } + + // 命令出错 + { + conn.Clear() + cmd := conn.Command("MGET", "test_1", "test_2").ExpectError(errors.New("error")) + res, missed := store.Gets([]string{"1", "2"}, "test_") + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + asserts.Len(missed, 2) + asserts.Len(res, 0) + } + + // 连接出错 + { + conn.Clear() + store.pool = &redis.Pool{ + Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, + MaxIdle: 10, + } + res, missed := store.Gets([]string{"1", "2"}, "test_") + asserts.Len(missed, 2) + asserts.Len(res, 0) + } +} + +func TestRedisStore_Sets(t *testing.T) { + asserts := assert.New(t) + conn := redigomock.NewConn() + pool := &redis.Pool{ + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + store := &RedisStore{pool: pool} + + // 正常 + { + cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK") + err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") + asserts.NoError(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + + // 序列化失败 + { + conn.Clear() + value := struct { + Key string + }{ + Key: "123", + } + err := store.Sets(map[string]interface{}{"1": value, "2": "2"}, "test_") + asserts.Error(err) + } + + // 执行失败 + { + cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error")) + err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") + asserts.Error(err) + if conn.Stats(cmd) != 1 { + fmt.Println("Command was not used") + return + } + } + + // 连接失败 + { + conn.Clear() + store.pool = &redis.Pool{ + Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, + MaxIdle: 10, + } + err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") + asserts.Error(err) + } +}