diff --git a/internal/tools/cron_task.go b/internal/tools/cron_task.go index 6f4803628..e22504bbb 100644 --- a/internal/tools/cron_task.go +++ b/internal/tools/cron_task.go @@ -17,13 +17,18 @@ package tools import ( "context" "fmt" - "sync" + "os" + "os/signal" + "syscall" + "time" + "github.com/redis/go-redis/v9" "github.com/robfig/cron/v3" "github.com/OpenIMSDK/tools/log" "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" ) func StartTask() error { @@ -32,23 +37,75 @@ func StartTask() error { if err != nil { return err } - msgTool.ConvertTools() - c := cron.New() - var wg sync.WaitGroup - wg.Add(1) + + msgTool.convertTools() + + rdb, err := cache.NewRedis() + if err != nil { + return err + } + + // register cron tasks + var crontab = cron.New() log.ZInfo(context.Background(), "start chatRecordsClearTime cron task", "cron config", config.Config.ChatRecordsClearTime) - _, err = c.AddFunc(config.Config.ChatRecordsClearTime, msgTool.AllConversationClearMsgAndFixSeq) + _, err = crontab.AddFunc(config.Config.ChatRecordsClearTime, cronWrapFunc(rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) if err != nil { log.ZError(context.Background(), "start allConversationClearMsgAndFixSeq cron failed", err) panic(err) } + log.ZInfo(context.Background(), "start msgDestruct cron task", "cron config", config.Config.MsgDestructTime) - _, err = c.AddFunc(config.Config.MsgDestructTime, msgTool.ConversationsDestructMsgs) + _, err = crontab.AddFunc(config.Config.MsgDestructTime, cronWrapFunc(rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) if err != nil { log.ZError(context.Background(), "start conversationsDestructMsgs cron failed", err) panic(err) } - c.Start() - wg.Wait() + + // start crontab + crontab.Start() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + <-sigs + + // stop crontab, Wait for the running task to exit. + ctx := crontab.Stop() + + select { + case <-ctx.Done(): + // graceful exit + + case <-time.After(15 * time.Second): + // forced exit on timeout + } + return nil } + +// netlock redis lock. +func netlock(rdb redis.UniversalClient, key string, ttl time.Duration) bool { + value := "used" + ok, err := rdb.SetNX(context.Background(), key, value, ttl).Result() // nolint + if err != nil { + // when err is about redis server, return true. + return false + } + + return ok +} + +func cronWrapFunc(rdb redis.UniversalClient, key string, fn func()) func() { + enableCronLocker := config.Config.EnableCronLocker + return func() { + // if don't enable cron-locker, call fn directly. + if !enableCronLocker { + fn() + return + } + + // when acquire redis lock, call fn(). + if netlock(rdb, key, 5*time.Second) { + fn() + } + } +} diff --git a/internal/tools/cron_task_test.go b/internal/tools/cron_task_test.go new file mode 100644 index 000000000..2fcfba01b --- /dev/null +++ b/internal/tools/cron_task_test.go @@ -0,0 +1,81 @@ +package tools + +import ( + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/redis/go-redis/v9" + "github.com/robfig/cron/v3" + "github.com/stretchr/testify/assert" +) + +func TestDisLock(t *testing.T) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + assert.Equal(t, true, netlock(rdb, "cron-1", 1*time.Second)) + + // if exists, get false + assert.Equal(t, false, netlock(rdb, "cron-1", 1*time.Second)) + + time.Sleep(2 * time.Second) + + // wait for key on timeout, get true + assert.Equal(t, true, netlock(rdb, "cron-1", 2*time.Second)) + + // set different key + assert.Equal(t, true, netlock(rdb, "cron-2", 2*time.Second)) +} + +func TestCronWrapFunc(t *testing.T) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + once := sync.Once{} + done := make(chan struct{}, 1) + cb := func() { + once.Do(func() { + close(done) + }) + } + + start := time.Now() + key := fmt.Sprintf("cron-%v", rand.Int31()) + crontab := cron.New(cron.WithSeconds()) + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, cb)) + crontab.Start() + <-done + + dur := time.Since(start) + assert.LessOrEqual(t, dur.Seconds(), float64(2*time.Second)) + crontab.Stop() +} + +func TestCronWrapFuncWithNetlock(t *testing.T) { + config.Config.EnableCronLocker = true + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + done := make(chan string, 10) + + crontab := cron.New(cron.WithSeconds()) + + key := fmt.Sprintf("cron-%v", rand.Int31()) + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, func() { + done <- "host1" + })) + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, func() { + done <- "host2" + })) + crontab.Start() + + time.Sleep(12 * time.Second) + // the ttl of netlock is 5s, so expected value is 2. + assert.Equal(t, len(done), 2) + + crontab.Stop() +} diff --git a/internal/tools/msg_doc_convert.go b/internal/tools/msg_doc_convert.go index aa24d385f..758625be1 100644 --- a/internal/tools/msg_doc_convert.go +++ b/internal/tools/msg_doc_convert.go @@ -22,7 +22,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" ) -func (c *MsgTool) ConvertTools() { +func (c *MsgTool) convertTools() { ctx := mcontext.NewCtx("convert") conversationIDs, err := c.conversationDatabase.GetAllConversationIDs(ctx) if err != nil { diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 5309f9913..94688b0fb 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -232,6 +232,7 @@ type configStruct struct { ChatRecordsClearTime string `yaml:"chatRecordsClearTime"` MsgDestructTime string `yaml:"msgDestructTime"` Secret string `yaml:"secret"` + EnableCronLocker bool `yaml:"enableCronLocker"` TokenPolicy struct { Expire int64 `yaml:"expire"` } `yaml:"tokenPolicy"` diff --git a/pkg/common/db/cache/init_redis.go b/pkg/common/db/cache/init_redis.go index 1a5507f89..77b38d9b7 100644 --- a/pkg/common/db/cache/init_redis.go +++ b/pkg/common/db/cache/init_redis.go @@ -28,12 +28,21 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) +var ( + // singleton pattern. + redisClient redis.UniversalClient +) + const ( maxRetry = 10 // number of retries ) // NewRedis Initialize redis connection. func NewRedis() (redis.UniversalClient, error) { + if redisClient != nil { + return redisClient, nil + } + if len(config.Config.Redis.Address) == 0 { return nil, errors.New("redis address is empty") } @@ -66,5 +75,6 @@ func NewRedis() (redis.UniversalClient, error) { return nil, fmt.Errorf("redis ping %w", err) } + redisClient = rdb return rdb, err }