diff --git a/bootstrap/init.go b/bootstrap/init.go index 4ee57d5..e5f2800 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -15,6 +15,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/wopi" "github.com/gin-gonic/gin" "io/fs" + "path/filepath" ) // Init 初始化启动 @@ -60,6 +61,12 @@ func Init(path string, statics fs.FS) { model.Init() }, }, + { + "both", + func() { + cache.Restore(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)) + }, + }, { "both", func() { diff --git a/go.mod b/go.mod index b613e35..a9b3cdb 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,8 @@ require ( github.com/gofrs/uuid v4.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 + github.com/gorilla/securecookie v1.1.1 + github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.3.0 github.com/jinzhu/gorm v1.9.11 @@ -31,6 +33,7 @@ require ( github.com/qiniu/go-sdk/v7 v7.11.1 github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 github.com/robfig/cron/v3 v3.0.1 + github.com/samber/lo v1.38.1 github.com/speps/go-hashids v2.0.0+incompatible github.com/stretchr/testify v1.7.2 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 @@ -48,7 +51,6 @@ require ( github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bgentry/speakeasy v0.1.0 // indirect - github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.1.1 // indirect @@ -83,8 +85,6 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/context v1.1.1 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/sessions v1.2.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect @@ -115,11 +115,9 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.24.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect - github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/samber/lo v1.38.1 // indirect github.com/satori/go.uuid v1.2.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/soheilhy/cmux v0.1.5 // indirect diff --git a/go.sum b/go.sum index 1421f3d..5071e89 100644 --- a/go.sum +++ b/go.sum @@ -131,8 +131,6 @@ github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb/go.mod h1:PkYb9DJNAwrSvRx5DYA+gUcOIgTGVMNkfSCbZM8cWpI= -github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04= -github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/caarlos0/ctrlc v1.0.0/go.mod h1:CdXpj4rmq0q/1Eb44M9zi2nKB0QraNKuRGYGrrHhcQw= @@ -438,7 +436,6 @@ github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2z github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= @@ -750,8 +747,6 @@ github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdk github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs= github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w= github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs= github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= @@ -1020,8 +1015,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1257,8 +1252,8 @@ golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI= golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index a33e9b3..5e214e1 100644 --- a/main.go +++ b/main.go @@ -9,11 +9,13 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" "github.com/cloudreve/Cloudreve/v3/bootstrap" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/routers" @@ -67,20 +69,10 @@ func main() { // 收到信号后关闭服务器 sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go func() { - sig := <-sigChan - util.Log().Info("Signal %s received, shutting down server...", sig) - ctx := context.Background() - if conf.SystemConfig.GracePeriod != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second) - defer cancel() - } + go shutdown(sigChan, server) - err := server.Shutdown(ctx) - if err != nil { - util.Log().Error("Failed to shutdown server: %s", err) - } + defer func() { + <-sigChan }() // 如果启用了SSL @@ -140,3 +132,27 @@ func RunUnix(server *http.Server) error { return server.Serve(listener) } + +func shutdown(sigChan chan os.Signal, server *http.Server) { + sig := <-sigChan + util.Log().Info("Signal %s received, shutting down server...", sig) + ctx := context.Background() + if conf.SystemConfig.GracePeriod != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second) + defer cancel() + } + + // Shutdown http server + err := server.Shutdown(ctx) + if err != nil { + util.Log().Error("Failed to shutdown server: %s", err) + } + + // Persist in-memory cache + if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil { + util.Log().Warning("Failed to persist cache: %s", err) + } + + close(sigChan) +} diff --git a/middleware/session.go b/middleware/session.go index 77825ae..db90755 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -1,6 +1,8 @@ package middleware import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore" "net/http" "strings" @@ -8,28 +10,16 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-contrib/sessions/redis" "github.com/gin-gonic/gin" ) // Store session存储 -var Store memstore.Store +var Store sessions.Store // Session 初始化session func Session(secret string) gin.HandlerFunc { // Redis设置不为空,且非测试模式时使用Redis - if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode { - var err error - Store, err = redis.NewStoreWithDB(10, conf.RedisConfig.Network, conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB, []byte(secret)) - if err != nil { - util.Log().Panic("Failed to connect to Redis:%s", err) - } - - util.Log().Info("Connect to Redis server %q.", conf.RedisConfig.Server) - } else { - Store = memstore.NewStore([]byte(secret)) - } + Store = sessionstore.NewStore(cache.Store, []byte(secret)) sameSiteMode := http.SameSiteDefaultMode switch strings.ToLower(conf.CORSConfig.SameSite) { diff --git a/middleware/session_test.go b/middleware/session_test.go index ac9403c..9fbe0d2 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -5,7 +5,6 @@ import ( "net/http/httptest" "testing" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" @@ -20,14 +19,6 @@ func TestSession(t *testing.T) { asserts.NotNil(Store) asserts.IsType(emptyFunc(), handler) } - { - conf.RedisConfig.Server = "123" - asserts.Panics(func() { - Session("2333") - }) - conf.RedisConfig.Server = "" - } - } func emptyFunc() gin.HandlerFunc { diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 046d271..bf49408 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -1,11 +1,16 @@ package cache import ( + "encoding/gob" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) +func init() { + gob.Register(map[string]itemWithTTL{}) +} + // Store 缓存存储器 var Store Driver = NewMemoStore() @@ -22,6 +27,13 @@ func Init() { } } +// Restore restores cache from given disk file +func Restore(persistFile string) { + if err := Store.Restore(persistFile); err != nil { + util.Log().Warning("Failed to restore cache from disk: %s", err) + } +} + func InitSlaveOverwrites() { err := Store.Sets(conf.OptionOverwrite, "setting_") if err != nil { @@ -45,6 +57,12 @@ type Driver interface { // 删除值 Delete(keys []string, prefix string) error + + // Save in-memory cache to disk + Persist(path string) error + + // Restore cache from disk + Restore(path string) error } // Set 设置缓存值 diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go index 0c55ba5..af180d6 100644 --- a/pkg/cache/memo.go +++ b/pkg/cache/memo.go @@ -1,6 +1,9 @@ package cache import ( + "encoding/gob" + "fmt" + "os" "sync" "time" @@ -14,18 +17,20 @@ type MemoStore struct { // item 存储的对象 type itemWithTTL struct { - expires int64 - value interface{} + Expires int64 + Value interface{} } +const DefaultCacheFile = "cache_persist.bin" + func newItem(value interface{}, expires int) itemWithTTL { expires64 := int64(expires) if expires > 0 { expires64 = time.Now().Unix() + expires64 } return itemWithTTL{ - value: value, - expires: expires64, + Value: value, + Expires: expires64, } } @@ -40,11 +45,11 @@ func getValue(item interface{}, ok bool) (interface{}, bool) { return item, true } - if itemObj.expires > 0 && itemObj.expires < time.Now().Unix() { + if itemObj.Expires > 0 && itemObj.Expires < time.Now().Unix() { return nil, false } - return itemObj.value, ok + return itemObj.Value, ok } @@ -52,7 +57,7 @@ func getValue(item interface{}, ok bool) (interface{}, bool) { func (store *MemoStore) GarbageCollect() { store.Store.Range(func(key, value interface{}) bool { if item, ok := value.(itemWithTTL); ok { - if item.expires > 0 && item.expires < time.Now().Unix() { + if item.Expires > 0 && item.Expires < time.Now().Unix() { util.Log().Debug("Cache %q is garbage collected.", key.(string)) store.Store.Delete(key) } @@ -98,7 +103,7 @@ func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface // Sets 批量设置值 func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error { for key, value := range values { - store.Store.Store(prefix+key, value) + store.Store.Store(prefix+key, newItem(value, 0)) } return nil } @@ -110,3 +115,61 @@ func (store *MemoStore) Delete(keys []string, prefix string) error { } return nil } + +// Persist write memory store into cache +func (store *MemoStore) Persist(path string) error { + persisted := make(map[string]itemWithTTL) + store.Store.Range(func(key, value interface{}) bool { + v, ok := store.Store.Load(key) + if _, ok := getValue(v, ok); ok { + persisted[key.(string)] = v.(itemWithTTL) + } + + return true + }) + + res, err := serializer(persisted) + if err != nil { + return fmt.Errorf("failed to serialize cache: %s", err) + } + + err = os.WriteFile(path, res, 0644) + return err +} + +// Restore memory cache from disk file +func (store *MemoStore) Restore(path string) error { + if !util.Exists(path) { + return nil + } + + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("failed to read cache file: %s", err) + } + + defer func() { + f.Close() + os.Remove(path) + }() + + persisted := &item{} + dec := gob.NewDecoder(f) + if err := dec.Decode(&persisted); err != nil { + return fmt.Errorf("unknown cache file format: %s", err) + } + + items := persisted.Value.(map[string]itemWithTTL) + loaded := 0 + for k, v := range items { + if _, ok := getValue(v, true); ok { + loaded++ + store.Store.Store(k, v) + } else { + util.Log().Debug("Persisted cache %q is expired.", k) + } + } + + util.Log().Info("Restored %d items from %q into memory cache.", loaded, path) + return nil +} diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go index 0765e64..be90577 100644 --- a/pkg/cache/memo_test.go +++ b/pkg/cache/memo_test.go @@ -2,6 +2,7 @@ package cache import ( "github.com/stretchr/testify/assert" + "path/filepath" "testing" "time" ) @@ -23,7 +24,7 @@ func TestMemoStore_Set(t *testing.T) { val, ok := store.Store.Load("KEY") asserts.True(ok) - asserts.Equal("vAL", val.(itemWithTTL).value) + asserts.Equal("vAL", val.(itemWithTTL).Value) } func TestMemoStore_Get(t *testing.T) { @@ -145,3 +146,46 @@ func TestMemoStore_GarbageCollect(t *testing.T) { _, ok := store.Get("test") asserts.False(ok) } + +func TestMemoStore_PersistFailed(t *testing.T) { + a := assert.New(t) + store := NewMemoStore() + type testStruct struct{ v string } + store.Set("test", 1, 0) + store.Set("test2", testStruct{v: "test"}, 0) + err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")) + a.Error(err) +} + +func TestMemoStore_PersistAndRestore(t *testing.T) { + a := assert.New(t) + store := NewMemoStore() + store.Set("test", 1, 0) + // already expired + store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1}) + // expired after persist + store.Set("test3", 1, 1) + temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed") + + // Persist + err := store.Persist(temp) + a.NoError(err) + a.FileExists(temp) + + time.Sleep(2 * time.Second) + // Restore + store2 := NewMemoStore() + err = store2.Restore(temp) + a.NoError(err) + test, testOk := store2.Get("test") + a.EqualValues(1, test) + a.True(testOk) + test2, test2Ok := store2.Get("test2") + a.Nil(test2) + a.False(test2Ok) + test3, test3Ok := store2.Get("test3") + a.Nil(test3) + a.False(test3Ok) + + a.NoFileExists(temp) +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index b9f7254..08bf11e 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -215,3 +215,13 @@ func (store *RedisStore) DeleteAll() error { return err } + +// Persist Dummy implementation +func (store *RedisStore) Persist(path string) error { + return nil +} + +// Restore dummy implementation +func (store *RedisStore) Restore(path string) error { + return nil +} diff --git a/pkg/sessionstore/kv.go b/pkg/sessionstore/kv.go new file mode 100644 index 0000000..193d5c6 --- /dev/null +++ b/pkg/sessionstore/kv.go @@ -0,0 +1,136 @@ +package sessionstore + +import ( + "bytes" + "encoding/base32" + "encoding/gob" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "net/http" + "strings" +) + +type kvStore struct { + Codecs []securecookie.Codec + Options *sessions.Options + DefaultMaxAge int + + prefix string + serializer SessionSerializer + store cache.Driver +} + +func newKvStore(prefix string, store cache.Driver, keyPairs ...[]byte) *kvStore { + return &kvStore{ + prefix: prefix, + store: store, + DefaultMaxAge: 60 * 20, + serializer: GobSerializer{}, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{ + Path: "/", + MaxAge: 86400 * 30, + }, + } +} + +// Get returns a session for the given name after adding it to the registry. +// +// It returns a new session if the sessions doesn't exist. Access IsNew on +// the session to check if it is an existing session or a new one. +// +// It returns a new session and an error if the session exists but could +// not be decoded. +func (s *kvStore) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(s, name) +} + +// New returns a session for the given name without adding it to the registry. +// +// The difference between New() and Get() is that calling New() twice will +// decode the session data twice, while Get() registers and reuses the same +// decoded session after the first call. +func (s *kvStore) New(r *http.Request, name string) (*sessions.Session, error) { + var ( + err error + ) + session := sessions.NewSession(s, name) + // make a copy + options := *s.Options + session.Options = &options + session.IsNew = true + if c, errCookie := r.Cookie(name); errCookie == nil { + err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) + if err == nil { + res, ok := s.store.Get(s.prefix + session.ID) + if ok { + err = s.serializer.Deserialize(res.([]byte), session) + } + + session.IsNew = !(err == nil && ok) // not new if no error and data available + } + } + return session, err +} +func (s *kvStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + // Marked for deletion. + if session.Options.MaxAge <= 0 { + if err := s.store.Delete([]string{session.ID}, s.prefix); err != nil { + return err + } + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + } else { + // Build an alphanumeric key for the redis store. + if session.ID == "" { + session.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32)), "=") + } + + b, err := s.serializer.Serialize(session) + if err != nil { + return err + } + + age := session.Options.MaxAge + if age == 0 { + age = s.DefaultMaxAge + } + + if err := s.store.Set(s.prefix+session.ID, b, age); err != nil { + return err + } + + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) + if err != nil { + return err + } + http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) + } + return nil +} + +// SessionSerializer provides an interface hook for alternative serializers +type SessionSerializer interface { + Deserialize(d []byte, ss *sessions.Session) error + Serialize(ss *sessions.Session) ([]byte, error) +} + +// GobSerializer uses gob package to encode the session map +type GobSerializer struct{} + +// Serialize using gob +func (s GobSerializer) Serialize(ss *sessions.Session) ([]byte, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(ss.Values) + if err == nil { + return buf.Bytes(), nil + } + return nil, err +} + +// Deserialize back to map[interface{}]interface{} +func (s GobSerializer) Deserialize(d []byte, ss *sessions.Session) error { + dec := gob.NewDecoder(bytes.NewBuffer(d)) + return dec.Decode(&ss.Values) +} diff --git a/pkg/sessionstore/sessionstore.go b/pkg/sessionstore/sessionstore.go new file mode 100644 index 0000000..3b1c302 --- /dev/null +++ b/pkg/sessionstore/sessionstore.go @@ -0,0 +1,22 @@ +package sessionstore + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/gin-contrib/sessions" +) + +type Store interface { + sessions.Store +} + +func NewStore(driver cache.Driver, keyPairs ...[]byte) Store { + return &store{newKvStore("cd_session_", driver, keyPairs...)} +} + +type store struct { + *kvStore +} + +func (c *store) Options(options sessions.Options) { + c.kvStore.Options = options.ToGorillaOptions() +}