From f083d52e171c4991ce4020cffc80c38480be06c3 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 9 Jun 2022 16:11:36 +0800 Subject: [PATCH] feat: tps limit for OneDrive policy --- assets | 2 +- models/policy.go | 4 +++ models/scripts/storage_test.go | 3 ++ pkg/filesystem/driver/onedrive/api.go | 5 +++ pkg/request/options.go | 34 ++++++++++++++------ pkg/request/request.go | 12 +++++-- pkg/request/request_test.go | 38 ++++++++++++++++++++++ pkg/request/tpslimiter.go | 39 +++++++++++++++++++++++ pkg/request/tpslimiter_test.go | 46 +++++++++++++++++++++++++++ 9 files changed, 170 insertions(+), 13 deletions(-) create mode 100644 pkg/request/tpslimiter.go create mode 100644 pkg/request/tpslimiter_test.go diff --git a/assets b/assets index c0f8a7e..41f585a 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit c0f8a7ef6ddd335b697347dce56271c3d3d8c215 +Subproject commit 41f585a6f8c8f99ed4b2e279555d6b4dcdf957bc diff --git a/models/policy.go b/models/policy.go index 1338a66..efe36e4 100644 --- a/models/policy.go +++ b/models/policy.go @@ -61,6 +61,10 @@ type PolicyOption struct { ChunkSize uint64 `json:"chunk_size,omitempty"` // 分片上传时是否需要预留空间 PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"` + // 每秒对存储端的 API 请求上限 + TPSLimit float64 `json:"tps_limit,omitempty"` + // 每秒 API 请求爆发上限 + TPSLimitBurst int `json:"tps_limit_burst,omitempty"` } // thumbSuffix 支持缩略图处理的文件扩展名 diff --git a/models/scripts/storage_test.go b/models/scripts/storage_test.go index da50e5b..746f0c0 100644 --- a/models/scripts/storage_test.go +++ b/models/scripts/storage_test.go @@ -52,6 +52,9 @@ func TestUserStorageCalibration_Run(t *testing.T) { mock.ExpectQuery("SELECT(.+)files(.+)"). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10)) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() script.Run(context.Background()) asserts.NoError(mock.ExpectationsWereMet()) } diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 7ff5711..438459a 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -544,6 +544,11 @@ func (client *Client) request(ctx context.Context, method string, url string, bo "Content-Type": {"application/json"}, }), request.WithContext(ctx), + request.WithTPSLimit( + fmt.Sprintf("policy_%d", client.Policy.ID), + client.Policy.OptionsSerialized.TPSLimit, + client.Policy.OptionsSerialized.TPSLimitBurst, + ), ) // 发送请求 diff --git a/pkg/request/options.go b/pkg/request/options.go index bb9b11c..dc0391e 100644 --- a/pkg/request/options.go +++ b/pkg/request/options.go @@ -15,15 +15,18 @@ type Option interface { } type options struct { - timeout time.Duration - header http.Header - sign auth.Auth - signTTL int64 - ctx context.Context - contentLength int64 - masterMeta bool - endpoint *url.URL - slaveNodeID string + timeout time.Duration + header http.Header + sign auth.Auth + signTTL int64 + ctx context.Context + contentLength int64 + masterMeta bool + endpoint *url.URL + slaveNodeID string + tpsLimiterToken string + tps float64 + tpsBurst int } type optionFunc func(*options) @@ -37,6 +40,7 @@ func newDefaultOption() *options { header: http.Header{}, timeout: time.Duration(30) * time.Second, contentLength: -1, + ctx: context.Background(), } } @@ -113,3 +117,15 @@ func WithEndpoint(endpoint string) Option { o.endpoint = endpointURL }) } + +// WithTPSLimit 请求时使用全局流量限制 +func WithTPSLimit(token string, tps float64, burst int) Option { + return optionFunc(func(o *options) { + o.tpsLimiterToken = token + o.tps = tps + if burst < 1 { + burst = 1 + } + o.tpsBurst = burst + }) +} diff --git a/pkg/request/request.go b/pkg/request/request.go index eb7996c..6ee78bc 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -34,13 +34,15 @@ type Client interface { // HTTPClient 实现 Client 接口 type HTTPClient struct { - mu sync.Mutex - options *options + mu sync.Mutex + options *options + tpsLimiter TPSLimiter } func NewClient(opts ...Option) Client { client := &HTTPClient{ - options: newDefaultOption(), + options: newDefaultOption(), + tpsLimiter: globalTPSLimiter, } for _, o := range opts { @@ -126,6 +128,10 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti } } + if options.tps > 0 { + c.tpsLimiter.Limit(options.ctx, options.tpsLimiterToken, options.tps, options.tpsBurst) + } + // 发送请求 resp, err := client.Do(req) if err != nil { diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go index 1f00f7e..e54831e 100644 --- a/pkg/request/request_test.go +++ b/pkg/request/request_test.go @@ -238,3 +238,41 @@ func TestBlackHole(t *testing.T) { BlackHole(strings.NewReader("TestBlackHole")) }) } + +func TestHTTPClient_TPSLimit(t *testing.T) { + a := assert.New(t) + client := NewClient() + + finished := make(chan struct{}) + go func() { + client.Request( + "POST", + "/test", + strings.NewReader(""), + WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1), + ) + close(finished) + }() + select { + case <-finished: + case <-time.After(10 * time.Second): + a.Fail("Request should be finished instantly.") + } + + finished = make(chan struct{}) + go func() { + client.Request( + "POST", + "/test", + strings.NewReader(""), + WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1), + ) + close(finished) + }() + select { + case <-finished: + case <-time.After(2 * time.Second): + a.Fail("Request should be finished in 1 second.") + } + +} diff --git a/pkg/request/tpslimiter.go b/pkg/request/tpslimiter.go new file mode 100644 index 0000000..edea0fa --- /dev/null +++ b/pkg/request/tpslimiter.go @@ -0,0 +1,39 @@ +package request + +import ( + "context" + "golang.org/x/time/rate" + "sync" +) + +var globalTPSLimiter = NewTPSLimiter() + +type TPSLimiter interface { + Limit(ctx context.Context, token string, tps float64, burst int) +} + +func NewTPSLimiter() TPSLimiter { + return &multipleBucketLimiter{ + buckets: make(map[string]*rate.Limiter), + } +} + +// multipleBucketLimiter implements TPSLimiter with multiple bucket support. +type multipleBucketLimiter struct { + mu sync.Mutex + buckets map[string]*rate.Limiter +} + +// Limit finds the given bucket, if bucket not exist or limit is changed, +// a new bucket will be generated. +func (m *multipleBucketLimiter) Limit(ctx context.Context, token string, tps float64, burst int) { + m.mu.Lock() + bucket, ok := m.buckets[token] + if !ok || float64(bucket.Limit()) != tps || bucket.Burst() != burst { + bucket = rate.NewLimiter(rate.Limit(tps), burst) + m.buckets[token] = bucket + } + m.mu.Unlock() + + bucket.Wait(ctx) +} diff --git a/pkg/request/tpslimiter_test.go b/pkg/request/tpslimiter_test.go new file mode 100644 index 0000000..daec236 --- /dev/null +++ b/pkg/request/tpslimiter_test.go @@ -0,0 +1,46 @@ +package request + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestLimit(t *testing.T) { + a := assert.New(t) + l := NewTPSLimiter() + finished := make(chan struct{}) + go func() { + l.Limit(context.Background(), "token", 1, 1) + close(finished) + }() + select { + case <-finished: + case <-time.After(10 * time.Second): + a.Fail("Limit should be finished instantly.") + } + + finished = make(chan struct{}) + go func() { + l.Limit(context.Background(), "token", 1, 1) + close(finished) + }() + select { + case <-finished: + case <-time.After(2 * time.Second): + a.Fail("Limit should be finished in 1 second.") + } + + finished = make(chan struct{}) + go func() { + l.Limit(context.Background(), "token", 10, 1) + close(finished) + }() + select { + case <-finished: + case <-time.After(1 * time.Second): + a.Fail("Limit should be finished instantly.") + } + +}