You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
273 lines
5.4 KiB
273 lines
5.4 KiB
6 months ago
|
package batcher
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"github.com/openimsdk/tools/errs"
|
||
|
"github.com/openimsdk/tools/utils/idutil"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// Package-level defaults applied by New when no Option overrides them.
// They are variables rather than constants, so a program may adjust them
// before constructing a Batcher.
var (
	// DefaultSize is the number of items accumulated before a batch is
	// dispatched.
	DefaultSize = 100
	// DefaultBuffer is the capacity of each per-worker channel.
	DefaultBuffer = 100
	// DefaultDataChanSize is the capacity of the main data channel.
	DefaultDataChanSize = 1000
	// DefaultWorker is the number of worker goroutines.
	DefaultWorker = 5
	// DefaultInterval is the period after which a partially filled batch is
	// dispatched anyway.
	DefaultInterval = time.Second
)
|
||
|
|
||
|
// Config holds the tunables of a Batcher. Its fields are unexported and are
// set through the With* options.
type Config struct {
	// size is the number of items that triggers dispatch of a batch.
	size int
	// buffer is the capacity of each per-worker channel.
	buffer int
	// dataBuffer is the requested capacity of the main data channel.
	dataBuffer int
	// worker is the number of goroutines that process batches in parallel.
	worker int
	// interval is the period after which pending items are dispatched even
	// if size has not been reached.
	interval time.Duration
	// syncWait makes the dispatcher block until the workers have finished
	// every message of a batch before it continues.
	syncWait bool
}
|
||
|
|
||
|
// Option mutates a Config; New applies the options it receives in order.
type Option func(cfg *Config)
|
||
|
|
||
|
func WithSize(s int) Option {
|
||
|
return func(c *Config) {
|
||
|
c.size = s
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithBuffer(b int) Option {
|
||
|
return func(c *Config) {
|
||
|
c.buffer = b
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithWorker(w int) Option {
|
||
|
return func(c *Config) {
|
||
|
c.worker = w
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithInterval(i time.Duration) Option {
|
||
|
return func(c *Config) {
|
||
|
c.interval = i
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithSyncWait(wait bool) Option {
|
||
|
return func(c *Config) {
|
||
|
c.syncWait = wait
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithDataBuffer(size int) Option {
|
||
|
return func(c *Config) {
|
||
|
c.dataBuffer = size
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type Batcher[T any] struct {
|
||
|
config *Config
|
||
|
|
||
|
globalCtx context.Context
|
||
|
cancel context.CancelFunc
|
||
|
Do func(ctx context.Context, channelID int, val *Msg[T])
|
||
|
OnComplete func(lastMessage *T, totalCount int)
|
||
|
Sharding func(key string) int
|
||
|
Key func(data *T) string
|
||
|
HookFunc func(triggerID string, messages map[string][]*T, totalCount int, lastMessage *T)
|
||
|
data chan *T
|
||
|
chArrays []chan *Msg[T]
|
||
|
wait sync.WaitGroup
|
||
|
counter sync.WaitGroup
|
||
|
}
|
||
|
|
||
|
// emptyOnComplete is the no-op default for Batcher.OnComplete.
func emptyOnComplete[T any](_ *T, _ int) {}
|
||
|
// emptyHookFunc is the no-op default for Batcher.HookFunc.
func emptyHookFunc[T any](_ string, _ map[string][]*T, _ int, _ *T) {}
|
||
|
|
||
|
func New[T any](opts ...Option) *Batcher[T] {
|
||
|
b := &Batcher[T]{
|
||
|
OnComplete: emptyOnComplete[T],
|
||
|
HookFunc: emptyHookFunc[T],
|
||
|
}
|
||
|
config := &Config{
|
||
|
size: DefaultSize,
|
||
|
buffer: DefaultBuffer,
|
||
|
worker: DefaultWorker,
|
||
|
interval: DefaultInterval,
|
||
|
}
|
||
|
for _, opt := range opts {
|
||
|
opt(config)
|
||
|
}
|
||
|
b.config = config
|
||
|
b.data = make(chan *T, DefaultDataChanSize)
|
||
|
b.globalCtx, b.cancel = context.WithCancel(context.Background())
|
||
|
|
||
|
b.chArrays = make([]chan *Msg[T], b.config.worker)
|
||
|
for i := 0; i < b.config.worker; i++ {
|
||
|
b.chArrays[i] = make(chan *Msg[T], b.config.buffer)
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) Worker() int {
|
||
|
return b.config.worker
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) Start() error {
|
||
|
if b.Sharding == nil {
|
||
|
return errs.New("Sharding function is required").Wrap()
|
||
|
}
|
||
|
if b.Do == nil {
|
||
|
return errs.New("Do function is required").Wrap()
|
||
|
}
|
||
|
if b.Key == nil {
|
||
|
return errs.New("Key function is required").Wrap()
|
||
|
}
|
||
|
b.wait.Add(b.config.worker)
|
||
|
for i := 0; i < b.config.worker; i++ {
|
||
|
go b.run(i, b.chArrays[i])
|
||
|
}
|
||
|
b.wait.Add(1)
|
||
|
go b.scheduler()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) Put(ctx context.Context, data *T) error {
|
||
|
if data == nil {
|
||
|
return errs.New("data can not be nil").Wrap()
|
||
|
}
|
||
|
select {
|
||
|
case <-b.globalCtx.Done():
|
||
|
return errs.New("data channel is closed").Wrap()
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
case b.data <- data:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) scheduler() {
|
||
|
ticker := time.NewTicker(b.config.interval)
|
||
|
defer func() {
|
||
|
ticker.Stop()
|
||
|
for _, ch := range b.chArrays {
|
||
|
close(ch)
|
||
|
}
|
||
|
close(b.data)
|
||
|
b.wait.Done()
|
||
|
}()
|
||
|
|
||
|
vals := make(map[string][]*T)
|
||
|
count := 0
|
||
|
var lastAny *T
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case data, ok := <-b.data:
|
||
|
if !ok {
|
||
|
// If the data channel is closed unexpectedly
|
||
|
return
|
||
|
}
|
||
|
if data == nil {
|
||
|
if count > 0 {
|
||
|
b.distributeMessage(vals, count, lastAny)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
key := b.Key(data)
|
||
|
vals[key] = append(vals[key], data)
|
||
|
lastAny = data
|
||
|
|
||
|
count++
|
||
|
if count >= b.config.size {
|
||
|
|
||
|
b.distributeMessage(vals, count, lastAny)
|
||
|
vals = make(map[string][]*T)
|
||
|
count = 0
|
||
|
}
|
||
|
|
||
|
case <-ticker.C:
|
||
|
if count > 0 {
|
||
|
|
||
|
b.distributeMessage(vals, count, lastAny)
|
||
|
vals = make(map[string][]*T)
|
||
|
count = 0
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Msg is the unit of work handed to a worker: all items of one batch that
// share the same key.
type Msg[T any] struct {
	key       string // grouping key shared by every item in val
	triggerID string // ID of the batch this message belongs to
	val       []*T   // the items, in arrival order
}

// Key returns the grouping key of the message.
func (m Msg[T]) Key() string {
	return m.key
}

// TriggerID returns the ID of the batch that produced the message.
func (m Msg[T]) TriggerID() string {
	return m.triggerID
}

// Val returns the items carried by the message. The slice is not copied.
func (m Msg[T]) Val() []*T {
	return m.val
}

// String renders the message as "Key: <key>, Values: [v1, v2, ...]" with each
// item formatted via %v. A nil item is rendered as "<nil>" instead of being
// dereferenced.
func (m Msg[T]) String() string {
	var sb strings.Builder
	sb.WriteString("Key: ")
	sb.WriteString(m.key)
	sb.WriteString(", Values: [")
	for i, v := range m.val {
		if i > 0 {
			sb.WriteString(", ")
		}
		if v == nil {
			sb.WriteString("<nil>")
			continue
		}
		// Write straight into the builder; no intermediate string.
		fmt.Fprintf(&sb, "%v", *v)
	}
	sb.WriteString("]")
	return sb.String()
}
|
||
|
|
||
|
func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, lastMessage *T) {
|
||
|
triggerID := idutil.OperationIDGenerator()
|
||
|
b.HookFunc(triggerID, messages, totalCount, lastMessage)
|
||
|
for key, data := range messages {
|
||
|
if b.config.syncWait {
|
||
|
b.counter.Add(1)
|
||
|
}
|
||
|
channelID := b.Sharding(key)
|
||
|
b.chArrays[channelID] <- &Msg[T]{key: key, triggerID: triggerID, val: data}
|
||
|
}
|
||
|
if b.config.syncWait {
|
||
|
b.counter.Wait()
|
||
|
}
|
||
|
b.OnComplete(lastMessage, totalCount)
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) {
|
||
|
defer b.wait.Done()
|
||
|
for {
|
||
|
select {
|
||
|
case messages, ok := <-ch:
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
b.Do(context.Background(), channelID, messages)
|
||
|
if b.config.syncWait {
|
||
|
b.counter.Done()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *Batcher[T]) Close() {
|
||
|
b.cancel() // Signal to stop put data
|
||
|
b.data <- nil
|
||
|
//wait all goroutines exit
|
||
|
b.wait.Wait()
|
||
|
}
|