package redis

import (
	"context"
	"errors"
	"fmt"

	"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
	"github.com/openimsdk/tools/errs"
	"github.com/openimsdk/tools/log"
	"github.com/redis/go-redis/v9"
)

// Lua scripts executed through callLua. The script bodies are kept verbatim so
// their hashes stay stable.
var (
	// setBatchWithCommonExpireScript SETs KEYS[i] to ARGV[i+1] and then applies
	// EXPIRE with ARGV[1] to every key. It returns the number of keys.
	setBatchWithCommonExpireScript = redis.NewScript(` local expire = tonumber(ARGV[1]) for i, key in ipairs(KEYS) do redis.call('SET', key, ARGV[i + 1]) redis.call('EXPIRE', key, expire) end return #KEYS `)

	// setBatchWithIndividualExpireScript SETs KEYS[i] to ARGV[i] and applies
	// EXPIRE with ARGV[i+n], where n is the number of keys. It returns n.
	setBatchWithIndividualExpireScript = redis.NewScript(` local n = #KEYS for i = 1, n do redis.call('SET', KEYS[i], ARGV[i]) redis.call('EXPIRE', KEYS[i], ARGV[i + n]) end return n `)

	// deleteBatchScript issues DEL for every key and returns the number of keys
	// it was given (not the number of keys that actually existed).
	deleteBatchScript = redis.NewScript(` for i, key in ipairs(KEYS) do redis.call('DEL', key) end return #KEYS `)

	// getBatchScript issues GET for every key and returns the collected values
	// in key order.
	getBatchScript = redis.NewScript(` local values = {} for i, key in ipairs(KEYS) do local value = redis.call('GET', key) table.insert(values, value) end return values `)
)

// callLua runs script against rdb with the given keys and args and returns the
// raw script result.
//
// It first tries EVALSHA. If the server answers NOSCRIPT, the script is loaded
// and EVALSHA is retried; if loading itself fails, it falls back to a plain
// EVAL. A redis.Nil result is treated as success. Any other error is wrapped
// together with the script hash, keys and args.
func callLua(ctx context.Context, rdb redis.Scripter, script *redis.Script, keys []string, args []any) (any, error) {
	log.ZDebug(ctx, "callLua args", "scriptHash", script.Hash(), "keys", keys, "args", args)

	// NOTE(review): args is passed as one variadic element rather than spread
	// with args...; this presumably relies on go-redis flattening a single
	// slice argument. Kept as in the original call sites - confirm before
	// changing.
	r := script.EvalSha(ctx, rdb, keys, args)
	if redis.HasErrorPrefix(r.Err(), "NOSCRIPT") {
		if loadErr := script.Load(ctx, rdb).Err(); loadErr != nil {
			// Could not cache the script; send the full source instead.
			r = script.Eval(ctx, rdb, keys, args)
		} else {
			r = script.EvalSha(ctx, rdb, keys, args)
		}
	}

	v, err := r.Result()
	if errors.Is(err, redis.Nil) {
		err = nil
	}
	// NOTE(review): this relies on errs.WrapMsg returning nil for a nil err,
	// as the original code did.
	return v, errs.WrapMsg(err, "call lua err", "scriptHash", script.Hash(), "keys", keys, "args", args)
}

// LuaSetBatchWithCommonExpire sets keys[i] to values[i] for every i and applies
// the same expire (passed to the Redis EXPIRE command) to all of them, in a
// single script invocation.
//
// It returns an error if keys and values differ in length, if the batch is too
// large to build the argument list safely, or if the script call fails.
func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expire int) error {
	if len(keys) != len(values) {
		return errs.New("keys and values length mismatch").Wrap()
	}

	// The argument list holds 1+len(values) entries; keep that sum within the
	// range of a 32-bit int so the capacity computation cannot overflow.
	const maxAllowedLen = (1 << 31) - 1
	if len(values) > maxAllowedLen-1 {
		return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen-1)).Wrap()
	}

	// ARGV layout: [expire, value1, value2, ...].
	vals := make([]any, 0, 1+len(values))
	vals = append(vals, expire)
	for _, v := range values {
		vals = append(vals, v)
	}

	_, err := callLua(ctx, rdb, setBatchWithCommonExpireScript, keys, vals)
	return err
}

// LuaSetBatchWithIndividualExpire sets keys[i] to values[i] and applies
// expires[i] (passed to the Redis EXPIRE command) to it, for every i, in a
// single script invocation.
//
// It returns an error if keys, values and expires do not all have the same
// length, if the batch is too large to build the argument list safely, or if
// the script call fails.
func LuaSetBatchWithIndividualExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expires []int) error {
	if len(keys) != len(values) || len(keys) != len(expires) {
		return errs.New(fmt.Sprintf("keys, values and expires length mismatch: %d keys, %d values, %d expires", len(keys), len(values), len(expires))).Wrap()
	}

	// The argument list holds len(values)+len(expires) entries, i.e. twice the
	// batch size, so the batch itself must not exceed half of the 32-bit int
	// range for the capacity computation to be safe.
	const maxAllowedLen = ((1 << 31) - 1) / 2
	if len(values) > maxAllowedLen {
		return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen)).Wrap()
	}

	// ARGV layout: [value1..valueN, expire1..expireN].
	vals := make([]any, 0, len(values)+len(expires))
	for _, v := range values {
		vals = append(vals, v)
	}
	for _, ex := range expires {
		vals = append(vals, ex)
	}

	_, err := callLua(ctx, rdb, setBatchWithIndividualExpireScript, keys, vals)
	return err
}

// LuaDeleteBatch deletes all given keys in a single script invocation.
func LuaDeleteBatch(ctx context.Context, rdb redis.Scripter, keys []string) error {
	_, err := callLua(ctx, rdb, deleteBatchScript, keys, nil)
	return err
}

// LuaGetBatch fetches the values of all given keys in a single script
// invocation and returns them in key order.
//
// It returns servererrs.ErrArgs if the script result is not a []any.
func LuaGetBatch(ctx context.Context, rdb redis.Scripter, keys []string) ([]any, error) {
	v, err := callLua(ctx, rdb, getBatchScript, keys, nil)
	if err != nil {
		return nil, err
	}
	values, ok := v.([]any)
	if !ok {
		return nil, servererrs.ErrArgs.WrapMsg("invalid lua get batch result")
	}
	return values, nil
}