From 392c824a337fbea5586d6812b9327a3070fae150 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sat, 15 Oct 2022 16:35:02 +0800 Subject: [PATCH] feat(OneDrive): support `Retry-After` throttling control from Graph API (#280) --- pkg/filesystem/chunk/backoff/backoff.go | 53 ++++++++++++++++-- pkg/filesystem/chunk/backoff/backoff_test.go | 59 ++++++++++++++++---- pkg/filesystem/chunk/chunk.go | 11 +++- pkg/filesystem/driver/onedrive/api.go | 18 +++--- pkg/filesystem/driver/onedrive/api_test.go | 29 ++++++++++ pkg/filesystem/driver/onedrive/types.go | 5 ++ 6 files changed, 148 insertions(+), 27 deletions(-) diff --git a/pkg/filesystem/chunk/backoff/backoff.go b/pkg/filesystem/chunk/backoff/backoff.go index d15b975..95cb1b5 100644 --- a/pkg/filesystem/chunk/backoff/backoff.go +++ b/pkg/filesystem/chunk/backoff/backoff.go @@ -1,14 +1,22 @@ package backoff -import "time" +import ( + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "net/http" + "strconv" + "time" +) // Backoff used for retry sleep backoff type Backoff interface { - Next() bool + Next(err error) bool Reset() } -// ConstantBackoff implements Backoff interface with constant sleep time +// ConstantBackoff implements Backoff interface with constant sleep time. If the error +// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration. type ConstantBackoff struct { Sleep time.Duration Max int @@ -16,16 +24,51 @@ type ConstantBackoff struct { tried int } -func (c *ConstantBackoff) Next() bool { +func (c *ConstantBackoff) Next(err error) bool { c.tried++ if c.tried > c.Max { return false } - time.Sleep(c.Sleep) + var e *RetryableError + if errors.As(err, &e) && e.RetryAfter > 0 { + util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter) + time.Sleep(e.RetryAfter) + } else { + time.Sleep(c.Sleep) + } + return true } func (c *ConstantBackoff) Reset() { c.tried = 0 } + +type RetryableError struct { + Err error + RetryAfter time.Duration +} + +// NewRetryableErrorFromHeader constructs a new RetryableError from http response header +// and existing error. +func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError { + retryAfter := header.Get("retry-after") + if retryAfter == "" { + retryAfter = "0" + } + + res := &RetryableError{ + Err: err, + } + + if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil { + res.RetryAfter = time.Duration(retryAfterSecond) * time.Second + } + + return res +} + +func (e *RetryableError) Error() string { + return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err) +} diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go index 6419c71..0fda534 100644 --- a/pkg/filesystem/chunk/backoff/backoff_test.go +++ b/pkg/filesystem/chunk/backoff/backoff_test.go @@ -1,7 +1,9 @@ package backoff import ( + "errors" "github.com/stretchr/testify/assert" + "net/http" "testing" "time" ) @@ -9,14 +11,51 @@ import ( func TestConstantBackoff_Next(t *testing.T) { a := assert.New(t) - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - a.True(b.Next()) - a.True(b.Next()) - a.True(b.Next()) - a.False(b.Next()) - b.Reset() - a.True(b.Next()) - a.True(b.Next()) - a.True(b.Next()) - a.False(b.Next()) + // General error + { + err := errors.New("error") + b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + b.Reset() + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + } + + // Retryable error + { + err := &RetryableError{RetryAfter: time.Duration(1)} + b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + b.Reset() + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + } + +} + +func TestNewRetryableErrorFromHeader(t *testing.T) { + a := assert.New(t) + // no retry-after header + { + err := NewRetryableErrorFromHeader(nil, http.Header{}) + a.Empty(err.RetryAfter) + } + + // with retry-after header + { + header := http.Header{} + header.Add("retry-after", "120") + err := NewRetryableErrorFromHeader(nil, header) + a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter) + } } diff --git a/pkg/filesystem/chunk/chunk.go b/pkg/filesystem/chunk/chunk.go index 24e50a1..cf790f6 100644 --- a/pkg/filesystem/chunk/chunk.go +++ b/pkg/filesystem/chunk/chunk.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/util" "io" "os" @@ -66,7 +67,7 @@ func (c *ChunkGroup) TempAvailable() bool { // Process a chunk with retry logic func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { - reader := io.LimitReader(c.file, int64(c.chunkSize)) + reader := io.LimitReader(c.file, c.Length()) // If useBuffer is enabled, tee the reader to a temp file if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() { @@ -90,13 +91,17 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { } util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name()) - reader = c.bufferTemp + reader = io.NopCloser(c.bufferTemp) } } err := processor(c, reader) if err != nil { - if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next() { + if c.enableRetryBuffer { + request.BlackHole(reader) + } + + if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) { if c.file.Seekable() { if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil { return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err) diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index e1a2219..2ec1663 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "io" - "io/ioutil" "net/http" "net/url" "path" @@ -51,11 +50,6 @@ func (info *FileInfo) GetSourcePath() string { ) } -// Error 实现error接口 -func (err RespError) Error() string { - return err.APIError.Message -} - func (client *Client) getRequestURL(api string, opts ...Option) string { options := newDefaultOption() for _, o := range opts { @@ -530,7 +524,7 @@ func sysError(err error) *RespError { }} } -func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, *RespError) { +func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) { // 获取凭证 err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave") if err != nil { @@ -579,15 +573,21 @@ func (client *Client) request(ctx context.Context, method string, url string, bo util.Log().Debug("Onedrive returns unknown response: %s", respBody) return "", sysError(decodeErr) } + + if res.Response.StatusCode == 429 { + util.Log().Warning("OneDrive request is throttled.") + return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header) + } + return "", &errResp } return respBody, nil } -func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, *RespError) { +func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) { // 发送请求 - bodyReader := ioutil.NopCloser(strings.NewReader(body)) + bodyReader := io.NopCloser(strings.NewReader(body)) return client.request(ctx, method, url, bodyReader, request.WithContentLength(int64(len(body))), ) diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go index 3ca7a33..a675548 100644 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -112,6 +112,35 @@ func TestRequest(t *testing.T) { asserts.Equal("error msg", err.Error()) } + // OneDrive返回429错误 + { + header := http.Header{} + header.Add("retry-after", "120") + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 429, + Header: header, + Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), + }, + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + var retryErr *backoff.RetryableError + asserts.ErrorAs(err, &retryErr) + asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter) + } + // OneDrive返回未知响应 { clientMock := ClientMock{} diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filesystem/driver/onedrive/types.go index 2a4307f..2a2ea4c 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filesystem/driver/onedrive/types.go @@ -133,3 +133,8 @@ type Site struct { func init() { gob.Register(Credential{}) } + +// Error 实现error接口 +func (err RespError) Error() string { + return err.APIError.Message +}