parent
31aba9b9ff
commit
9c19fd43fa
@ -0,0 +1,269 @@
|
||||
package batcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/utils/idutil"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultDataChanSize = 1000
|
||||
DefaultSize = 100
|
||||
DefaultBuffer = 100
|
||||
DefaultWorker = 5
|
||||
DefaultInterval = time.Second
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
size int // Number of message aggregations
|
||||
buffer int // The number of caches running in a single coroutine
|
||||
worker int // Number of coroutines processed in parallel
|
||||
interval time.Duration // Time of message aggregations
|
||||
syncWait bool // Whether to wait synchronously after distributing messages
|
||||
}
|
||||
|
||||
type Option func(c *Config)
|
||||
|
||||
func WithSize(s int) Option {
|
||||
return func(c *Config) {
|
||||
c.size = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithBuffer(b int) Option {
|
||||
return func(c *Config) {
|
||||
c.buffer = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorker(w int) Option {
|
||||
return func(c *Config) {
|
||||
c.worker = w
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterval(i time.Duration) Option {
|
||||
return func(c *Config) {
|
||||
c.interval = i
|
||||
}
|
||||
}
|
||||
|
||||
func WithSyncWait(wait bool) Option {
|
||||
return func(c *Config) {
|
||||
c.syncWait = wait
|
||||
}
|
||||
}
|
||||
|
||||
type Batcher[T any] struct {
|
||||
config *Config
|
||||
|
||||
globalCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
Do func(ctx context.Context, channelID int, val *Msg[T])
|
||||
OnComplete func(lastMessage *T, totalCount int)
|
||||
Sharding func(key string) int
|
||||
Key func(data *T) string
|
||||
HookFunc func(triggerID string, messages map[string][]*T, totalCount int, lastMessage *T)
|
||||
data chan *T
|
||||
chArrays []chan *Msg[T]
|
||||
wait sync.WaitGroup
|
||||
counter sync.WaitGroup
|
||||
}
|
||||
|
||||
func emptyOnComplete[T any](*T, int) {}
|
||||
func emptyHookFunc[T any](string, map[string][]*T, int, *T) {
|
||||
}
|
||||
|
||||
func New[T any](opts ...Option) *Batcher[T] {
|
||||
b := &Batcher[T]{
|
||||
OnComplete: emptyOnComplete[T],
|
||||
HookFunc: emptyHookFunc[T],
|
||||
}
|
||||
config := &Config{
|
||||
size: DefaultSize,
|
||||
buffer: DefaultBuffer,
|
||||
worker: DefaultWorker,
|
||||
interval: DefaultInterval,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(config)
|
||||
}
|
||||
b.config = config
|
||||
b.data = make(chan *T, DefaultDataChanSize)
|
||||
b.globalCtx, b.cancel = context.WithCancel(context.Background())
|
||||
|
||||
b.chArrays = make([]chan *Msg[T], b.config.worker)
|
||||
for i := 0; i < b.config.worker; i++ {
|
||||
b.chArrays[i] = make(chan *Msg[T], b.config.buffer)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) Worker() int {
|
||||
return b.config.worker
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) Start() error {
|
||||
if b.Sharding == nil {
|
||||
return errs.New("Sharding function is required").Wrap()
|
||||
}
|
||||
if b.Do == nil {
|
||||
return errs.New("Do function is required").Wrap()
|
||||
}
|
||||
if b.Key == nil {
|
||||
return errs.New("Key function is required").Wrap()
|
||||
}
|
||||
b.wait.Add(b.config.worker)
|
||||
for i := 0; i < b.config.worker; i++ {
|
||||
go b.run(i, b.chArrays[i])
|
||||
}
|
||||
b.wait.Add(1)
|
||||
go b.scheduler()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) Put(ctx context.Context, data *T) error {
|
||||
if data == nil {
|
||||
return errs.New("data can not be nil").Wrap()
|
||||
}
|
||||
select {
|
||||
case <-b.globalCtx.Done():
|
||||
return errs.New("data channel is closed").Wrap()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case b.data <- data:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) scheduler() {
|
||||
ticker := time.NewTicker(b.config.interval)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
for _, ch := range b.chArrays {
|
||||
close(ch) // 发送关闭信号到每个worker
|
||||
}
|
||||
close(b.data)
|
||||
b.wait.Done()
|
||||
}()
|
||||
|
||||
vals := make(map[string][]*T)
|
||||
count := 0
|
||||
var lastAny *T
|
||||
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-b.data:
|
||||
if !ok {
|
||||
// 如果data channel意外关闭
|
||||
return
|
||||
}
|
||||
if data == nil {
|
||||
// 接收到nil作为结束信号
|
||||
fmt.Println("Batcher Closing1", count)
|
||||
if count > 0 {
|
||||
fmt.Println("Batcher Closing2", count)
|
||||
b.distributeMessage(vals, count, lastAny)
|
||||
}
|
||||
return
|
||||
}
|
||||
// 正常数据处理
|
||||
key := b.Key(data)
|
||||
vals[key] = append(vals[key], data)
|
||||
lastAny = data
|
||||
|
||||
count++
|
||||
if count >= b.config.size {
|
||||
|
||||
fmt.Printf("counter to %d, %v\n", count, lastAny)
|
||||
b.distributeMessage(vals, count, lastAny)
|
||||
vals = make(map[string][]*T)
|
||||
count = 0
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
if count > 0 {
|
||||
fmt.Printf("ticker to %v , %d, %v\n", b.config.interval, count, lastAny)
|
||||
b.distributeMessage(vals, count, lastAny)
|
||||
vals = make(map[string][]*T)
|
||||
count = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Msg[T any] struct {
|
||||
key string
|
||||
triggerID string
|
||||
val []*T
|
||||
}
|
||||
|
||||
func (m Msg[T]) Key() string {
|
||||
return m.key
|
||||
}
|
||||
|
||||
func (m Msg[T]) TriggerID() string {
|
||||
return m.triggerID
|
||||
}
|
||||
|
||||
func (m Msg[T]) Val() []*T {
|
||||
return m.val
|
||||
}
|
||||
|
||||
func (m Msg[T]) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Key: ")
|
||||
sb.WriteString(m.key)
|
||||
sb.WriteString(", Values: [")
|
||||
for i, v := range m.val {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, lastMessage *T) {
|
||||
triggerID := idutil.OperationIDGenerator()
|
||||
b.HookFunc(triggerID, messages, totalCount, lastMessage)
|
||||
for key, data := range messages {
|
||||
if b.config.syncWait {
|
||||
b.counter.Add(1)
|
||||
}
|
||||
channelID := b.Sharding(key)
|
||||
b.chArrays[channelID] <- &Msg[T]{key: key, triggerID: triggerID, val: data}
|
||||
}
|
||||
if b.config.syncWait {
|
||||
b.counter.Wait()
|
||||
}
|
||||
b.OnComplete(lastMessage, totalCount)
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) {
|
||||
defer b.wait.Done()
|
||||
for {
|
||||
select {
|
||||
case messages, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.Do(context.Background(), channelID, messages)
|
||||
if b.config.syncWait {
|
||||
b.counter.Done()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batcher[T]) Close() {
|
||||
b.cancel() // Signal to stop put data
|
||||
b.data <- nil
|
||||
//wait all goroutines exit
|
||||
b.wait.Wait()
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
package batcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBatcher(t *testing.T) {
|
||||
config := Config{
|
||||
size: 1000,
|
||||
buffer: 10,
|
||||
worker: 10,
|
||||
interval: 5 * time.Millisecond,
|
||||
}
|
||||
|
||||
b := New[string](
|
||||
WithSize(config.size),
|
||||
WithBuffer(config.buffer),
|
||||
WithWorker(config.worker),
|
||||
WithInterval(config.interval),
|
||||
WithSyncWait(true),
|
||||
)
|
||||
|
||||
// Mock Do function to simply print values for demonstration
|
||||
b.Do = func(ctx context.Context, channelID int, vals *Msg[string]) {
|
||||
t.Logf("Channel %d Processed batch: %v", channelID, vals)
|
||||
}
|
||||
b.OnComplete = func(lastMessage *string, totalCount int) {
|
||||
t.Logf("Completed processing with last message: %v, total count: %d", *lastMessage, totalCount)
|
||||
}
|
||||
b.Sharding = func(key string) int {
|
||||
hashCode := stringutil.GetHashCode(key)
|
||||
return int(hashCode) % config.worker
|
||||
}
|
||||
b.Key = func(data *string) string {
|
||||
return *data
|
||||
}
|
||||
|
||||
err := b.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test normal data processing
|
||||
for i := 0; i < 10000; i++ {
|
||||
data := "data" + fmt.Sprintf("%d", i)
|
||||
if err := b.Put(context.Background(), &data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(1) * time.Second)
|
||||
start := time.Now()
|
||||
// Wait for all processing to finish
|
||||
b.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
t.Logf("Close took %s", elapsed)
|
||||
|
||||
if len(b.data) != 0 {
|
||||
t.Error("Data channel should be empty after closing")
|
||||
}
|
||||
}
|
Loading…
Reference in new issue