fix: add protective measures against memory overflow.

pull/2325/head
Gordon 1 year ago
parent d6bbc749f8
commit 1cae1b0330

@ -2,6 +2,7 @@ package redis
import ( import (
"context" "context"
"fmt"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log" "github.com/openimsdk/tools/log"
@ -62,6 +63,17 @@ func callLua(ctx context.Context, rdb redis.Scripter, script *redis.Script, keys
} }
func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expire int) error { func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expire int) error {
// Check if the lengths of keys and values match
if len(keys) != len(values) {
return errs.New("keys and values length mismatch").Wrap()
}
// Ensure allocation size does not overflow
maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems)
if len(values) > maxAllowedLen-1 {
return fmt.Errorf("values length is too large, causing overflow")
}
var vals = make([]any, 0, 1+len(values)) var vals = make([]any, 0, 1+len(values))
vals = append(vals, expire) vals = append(vals, expire)
for _, v := range values { for _, v := range values {
@ -72,6 +84,17 @@ func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys [
} }
func LuaSetBatchWithIndividualExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expires []int) error { func LuaSetBatchWithIndividualExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expires []int) error {
// Check if the lengths of keys, values, and expires match
if len(keys) != len(values) || len(keys) != len(expires) {
return errs.New("keys and values length mismatch").Wrap()
}
// Ensure the allocation size does not overflow
maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems)
if len(values) > maxAllowedLen-1 {
return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen-1)).Wrap()
}
var vals = make([]any, 0, len(values)+len(expires)) var vals = make([]any, 0, len(values)+len(expires))
for _, v := range values { for _, v := range values {
vals = append(vals, v) vals = append(vals, v)

Loading…
Cancel
Save