feat: tps limit for OneDrive policy

pull/1380/head
HFO4 2 years ago
parent 4859ea6ee5
commit f083d52e17

@ -1 +1 @@
Subproject commit c0f8a7ef6ddd335b697347dce56271c3d3d8c215
Subproject commit 41f585a6f8c8f99ed4b2e279555d6b4dcdf957bc

@ -61,6 +61,10 @@ type PolicyOption struct {
ChunkSize uint64 `json:"chunk_size,omitempty"`
// 分片上传时是否需要预留空间
PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"`
// 每秒对存储端的 API 请求上限
TPSLimit float64 `json:"tps_limit,omitempty"`
// 每秒 API 请求爆发上限
TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
}
// thumbSuffix 支持缩略图处理的文件扩展名

@ -52,6 +52,9 @@ func TestUserStorageCalibration_Run(t *testing.T) {
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
asserts.NoError(mock.ExpectationsWereMet())
}

@ -544,6 +544,11 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
"Content-Type": {"application/json"},
}),
request.WithContext(ctx),
request.WithTPSLimit(
fmt.Sprintf("policy_%d", client.Policy.ID),
client.Policy.OptionsSerialized.TPSLimit,
client.Policy.OptionsSerialized.TPSLimitBurst,
),
)
// 发送请求

@ -15,15 +15,18 @@ type Option interface {
}
type options struct {
timeout time.Duration
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
endpoint *url.URL
slaveNodeID string
timeout time.Duration
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
endpoint *url.URL
slaveNodeID string
tpsLimiterToken string
tps float64
tpsBurst int
}
type optionFunc func(*options)
@ -37,6 +40,7 @@ func newDefaultOption() *options {
header: http.Header{},
timeout: time.Duration(30) * time.Second,
contentLength: -1,
ctx: context.Background(),
}
}
@ -113,3 +117,15 @@ func WithEndpoint(endpoint string) Option {
o.endpoint = endpointURL
})
}
// WithTPSLimit 请求时使用全局流量限制
func WithTPSLimit(token string, tps float64, burst int) Option {
return optionFunc(func(o *options) {
o.tpsLimiterToken = token
o.tps = tps
if burst < 1 {
burst = 1
}
o.tpsBurst = burst
})
}

@ -34,13 +34,15 @@ type Client interface {
// HTTPClient 实现 Client 接口
type HTTPClient struct {
mu sync.Mutex
options *options
mu sync.Mutex
options *options
tpsLimiter TPSLimiter
}
func NewClient(opts ...Option) Client {
client := &HTTPClient{
options: newDefaultOption(),
options: newDefaultOption(),
tpsLimiter: globalTPSLimiter,
}
for _, o := range opts {
@ -126,6 +128,10 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
}
}
if options.tps > 0 {
c.tpsLimiter.Limit(options.ctx, options.tpsLimiterToken, options.tps, options.tpsBurst)
}
// 发送请求
resp, err := client.Do(req)
if err != nil {

@ -238,3 +238,41 @@ func TestBlackHole(t *testing.T) {
BlackHole(strings.NewReader("TestBlackHole"))
})
}
func TestHTTPClient_TPSLimit(t *testing.T) {
a := assert.New(t)
client := NewClient()
finished := make(chan struct{})
go func() {
client.Request(
"POST",
"/test",
strings.NewReader(""),
WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
)
close(finished)
}()
select {
case <-finished:
case <-time.After(10 * time.Second):
a.Fail("Request should be finished instantly.")
}
finished = make(chan struct{})
go func() {
client.Request(
"POST",
"/test",
strings.NewReader(""),
WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
)
close(finished)
}()
select {
case <-finished:
case <-time.After(2 * time.Second):
a.Fail("Request should be finished in 1 second.")
}
}

@ -0,0 +1,39 @@
package request
import (
"context"
"golang.org/x/time/rate"
"sync"
)
var globalTPSLimiter = NewTPSLimiter()
type TPSLimiter interface {
Limit(ctx context.Context, token string, tps float64, burst int)
}
func NewTPSLimiter() TPSLimiter {
return &multipleBucketLimiter{
buckets: make(map[string]*rate.Limiter),
}
}
// multipleBucketLimiter implements TPSLimiter with multiple bucket support.
type multipleBucketLimiter struct {
mu sync.Mutex
buckets map[string]*rate.Limiter
}
// Limit finds the given bucket, if bucket not exist or limit is changed,
// a new bucket will be generated.
func (m *multipleBucketLimiter) Limit(ctx context.Context, token string, tps float64, burst int) {
m.mu.Lock()
bucket, ok := m.buckets[token]
if !ok || float64(bucket.Limit()) != tps || bucket.Burst() != burst {
bucket = rate.NewLimiter(rate.Limit(tps), burst)
m.buckets[token] = bucket
}
m.mu.Unlock()
bucket.Wait(ctx)
}

@ -0,0 +1,46 @@
package request
import (
"context"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestLimit(t *testing.T) {
a := assert.New(t)
l := NewTPSLimiter()
finished := make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 1, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(10 * time.Second):
a.Fail("Limit should be finished instantly.")
}
finished = make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 1, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(2 * time.Second):
a.Fail("Limit should be finished in 1 second.")
}
finished = make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 10, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(1 * time.Second):
a.Fail("Limit should be finished instantly.")
}
}
Loading…
Cancel
Save