diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go new file mode 100644 index 0000000..6419c71 --- /dev/null +++ b/pkg/filesystem/chunk/backoff/backoff_test.go @@ -0,0 +1,22 @@ +package backoff + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestConstantBackoff_Next(t *testing.T) { + a := assert.New(t) + + b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} + a.True(b.Next()) + a.True(b.Next()) + a.True(b.Next()) + a.False(b.Next()) + b.Reset() + a.True(b.Next()) + a.True(b.Next()) + a.True(b.Next()) + a.False(b.Next()) +} diff --git a/pkg/filesystem/chunk/chunk.go b/pkg/filesystem/chunk/chunk.go index 5313558..24e50a1 100644 --- a/pkg/filesystem/chunk/chunk.go +++ b/pkg/filesystem/chunk/chunk.go @@ -42,9 +42,13 @@ func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Back c.chunkSize = c.fileInfo.Size } - c.chunkNum = c.fileInfo.Size / c.chunkSize - if c.fileInfo.Size%c.chunkSize != 0 || c.fileInfo.Size == 0 { - c.chunkNum++ + if c.fileInfo.Size == 0 { + c.chunkNum = 1 + } else { + c.chunkNum = c.fileInfo.Size / c.chunkSize + if c.fileInfo.Size%c.chunkSize != 0 { + c.chunkNum++ + } } return c @@ -95,7 +99,7 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next() { if c.file.Seekable() { if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil { - return fmt.Errorf("failed to seek back to chunk start: %w, last error: %w", seekErr, err) + return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err) } } @@ -115,7 +119,7 @@ func (c *ChunkGroup) Start() int64 { return int64(uint64(c.Index()) * c.chunkSize) } -// Total returns the total length current chunk +// Total returns the total length func (c *ChunkGroup) Total() int64 { return int64(c.fileInfo.Size) } diff --git a/pkg/filesystem/chunk/chunk_test.go b/pkg/filesystem/chunk/chunk_test.go new file mode 100644 index 0000000..c6af9d9 --- /dev/null +++ b/pkg/filesystem/chunk/chunk_test.go @@ -0,0 +1,250 @@ +package chunk + +import ( + "errors" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/stretchr/testify/assert" + "io" + "os" + "strings" + "testing" +) + +func TestNewChunkGroup(t *testing.T) { + a := assert.New(t) + + testCases := []struct { + fileSize uint64 + chunkSize uint64 + expectedInnerChunkSize uint64 + expectedChunkNum uint64 + expectedInfo [][2]int //Start, Index,Length + }{ + {10, 0, 10, 1, [][2]int{{0, 10}}}, + {0, 0, 0, 1, [][2]int{{0, 0}}}, + {0, 10, 10, 1, [][2]int{{0, 0}}}, + {50, 10, 10, 5, [][2]int{ + {0, 10}, + {10, 10}, + {20, 10}, + {30, 10}, + {40, 10}, + }}, + {50, 50, 50, 1, [][2]int{ + {0, 50}, + }}, + + {50, 15, 15, 4, [][2]int{ + {0, 15}, + {15, 15}, + {30, 15}, + {45, 5}, + }}, + } + + for index, testCase := range testCases { + file := &fsctx.FileStream{Size: testCase.fileSize} + chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true) + a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), + "TestCase:%d,ChunkNum()", index) + a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize, + "TestCase:%d,InnerChunkSize()", index) + a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), + "TestCase:%d,len(Chunks)", index) + a.EqualValues(testCase.fileSize, chunkGroup.Total()) + + for cIndex, info := range testCase.expectedInfo { + a.True(chunkGroup.Next()) + a.EqualValues(info[1], chunkGroup.Length(), + "TestCase:%d,Chunks[%d].Length()", index, cIndex) + a.EqualValues(info[0], chunkGroup.Start(), + "TestCase:%d,Chunks[%d].Start()", index, cIndex) + + a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(), + "TestCase:%d,Chunks[%d].IsLast()", index, cIndex) + + a.NotEmpty(chunkGroup.RangeHeader()) + } + a.False(chunkGroup.Next()) + } +} + +func TestChunkGroup_TempAvailablet(t *testing.T) { + a := assert.New(t) + + file := &fsctx.FileStream{Size: 1} + c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true) + a.False(c.TempAvailable()) + + f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*") + defer func() { + f.Close() + os.Remove(f.Name()) + }() + a.NoError(err) + c.bufferTemp = f + + a.False(c.TempAvailable()) + f.Write([]byte("1")) + a.True(c.TempAvailable()) + +} + +func TestChunkGroup_Process(t *testing.T) { + a := assert.New(t) + file := &fsctx.FileStream{Size: 10} + + // success + { + file.File = io.NopCloser(strings.NewReader("1234567890")) + c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true) + count := 0 + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("12345", string(res)) + return nil + })) + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("67890", string(res)) + return nil + })) + a.False(c.Next()) + a.Equal(2, count) + } + + // retry, read from buffer file + { + file.File = io.NopCloser(strings.NewReader("1234567890")) + c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true) + count := 0 + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("12345", string(res)) + return nil + })) + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("67890", string(res)) + if count == 2 { + return errors.New("error") + } + return nil + })) + a.False(c.Next()) + a.Equal(3, count) + } + + // retry, read from seeker + { + f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") + f.Write([]byte("1234567890")) + f.Seek(0, 0) + defer func() { + f.Close() + os.Remove(f.Name()) + }() + file.File = f + file.Seeker = f + c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) + count := 0 + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("12345", string(res)) + return nil + })) + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("67890", string(res)) + if count == 2 { + return errors.New("error") + } + return nil + })) + a.False(c.Next()) + a.Equal(3, count) + } + + // retry, seek error + { + f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") + f.Write([]byte("1234567890")) + f.Seek(0, 0) + defer func() { + f.Close() + os.Remove(f.Name()) + }() + file.File = f + file.Seeker = f + c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) + count := 0 + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("12345", string(res)) + return nil + })) + a.True(c.Next()) + f.Close() + a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + if count == 2 { + return errors.New("error") + } + return nil + })) + a.False(c.Next()) + a.Equal(2, count) + } + + // retry, finally error + { + f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") + f.Write([]byte("1234567890")) + f.Seek(0, 0) + defer func() { + f.Close() + os.Remove(f.Name()) + }() + file.File = f + file.Seeker = f + c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) + count := 0 + a.True(c.Next()) + a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + res, err := io.ReadAll(chunk) + a.NoError(err) + a.EqualValues("12345", string(res)) + return nil + })) + a.True(c.Next()) + a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { + count++ + return errors.New("error") + })) + a.False(c.Next()) + a.Equal(1, count) + } +} diff --git a/pkg/filesystem/driver/local/handler.go b/pkg/filesystem/driver/local/handler.go index 7e665c8..e5e8994 100644 --- a/pkg/filesystem/driver/local/handler.go +++ b/pkg/filesystem/driver/local/handler.go @@ -161,7 +161,7 @@ func (handler Driver) Truncate(ctx context.Context, src string, size uint64) err util.Log().Warning("截断文件 [%s] 至 [%d]", src, size) out, err := os.OpenFile(src, os.O_WRONLY, Perm) if err != nil { - util.Log().Warning("无法打开或创建文件,%s", err) + util.Log().Warning("无法打开文件,%s", err) return err } diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go index dac4d54..9ce5fe7 100644 --- a/pkg/filesystem/driver/local/handler_test.go +++ b/pkg/filesystem/driver/local/handler_test.go @@ -4,13 +4,12 @@ import ( "context" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "io" - "io/ioutil" "net/url" "os" "strings" @@ -20,42 +19,64 @@ import ( func TestHandler_Put(t *testing.T) { asserts := assert.New(t) handler := Driver{} - ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) - os.Remove(util.RelativePath("test/test/txt")) + + defer func() { + os.Remove(util.RelativePath("TestHandler_Put.txt")) + os.Remove(util.RelativePath("inner/TestHandler_Put.txt")) + }() testCases := []struct { - file io.ReadCloser - dst string - err bool + file fsctx.FileHeader + errContains string }{ - { - file: ioutil.NopCloser(strings.NewReader("test input file")), - dst: "test/test/txt", - err: false, - }, - { - file: ioutil.NopCloser(strings.NewReader("test input file")), - dst: "test/test/txt", - err: true, - }, - { - file: ioutil.NopCloser(strings.NewReader("test input file")), - dst: "/notexist:/S.TXT", - err: true, - }, + {&fsctx.FileStream{ + SavePath: "TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("")), + }, ""}, + {&fsctx.FileStream{ + SavePath: "TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("")), + }, "物理同名文件已存在或不可用"}, + {&fsctx.FileStream{ + SavePath: "inner/TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("")), + }, ""}, + {&fsctx.FileStream{ + Mode: fsctx.Append | fsctx.Overwrite, + SavePath: "inner/TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("123")), + }, ""}, + {&fsctx.FileStream{ + AppendStart: 10, + Mode: fsctx.Append | fsctx.Overwrite, + SavePath: "inner/TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("123")), + }, "未上传完成的文件分片与预期大小不一致"}, + {&fsctx.FileStream{ + Mode: fsctx.Append | fsctx.Overwrite, + SavePath: "inner/TestHandler_Put.txt", + File: io.NopCloser(strings.NewReader("123")), + }, ""}, } for _, testCase := range testCases { - err := handler.Put(ctx, testCase.file, testCase.dst, 15) - if testCase.err { + err := handler.Put(context.Background(), testCase.file) + if testCase.errContains != "" { asserts.Error(err) + asserts.Contains(err.Error(), testCase.errContains) } else { asserts.NoError(err) - asserts.True(util.Exists(util.RelativePath(testCase.dst))) + asserts.True(util.Exists(util.RelativePath(testCase.file.Info().SavePath))) } } } +func TestDriver_TruncateFailed(t *testing.T) { + a := assert.New(t) + h := Driver{} + a.Error(h.Truncate(context.Background(), "TestDriver_TruncateFailed", 0)) +} + func TestHandler_Delete(t *testing.T) { asserts := assert.New(t) handler := Driver{} @@ -116,7 +137,7 @@ func TestHandler_Thumb(t *testing.T) { asserts := assert.New(t) handler := Driver{} ctx := context.Background() - file, err := os.Create(util.RelativePath("TestHandler_Thumb" + conf.ThumbConfig.FileSuffix)) + file, err := os.Create(util.RelativePath("TestHandler_Thumb._thumb")) asserts.NoError(err) file.Close() @@ -160,6 +181,25 @@ func TestHandler_Source(t *testing.T) { asserts.Contains(sourceURL, "https://cloudreve.org") } + // 下载 + { + file := model.File{ + Model: gorm.Model{ + ID: 1, + }, + Name: "test.jpg", + } + ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0, true, 0) + asserts.NoError(err) + asserts.NotEmpty(sourceURL) + asserts.Contains(sourceURL, "sign=") + asserts.Contains(sourceURL, "download") + asserts.Contains(sourceURL, "https://cloudreve.org") + } + // 无法获取上下文 { baseURL, err := url.Parse("https://cloudreve.org") @@ -241,10 +281,29 @@ func TestHandler_GetDownloadURL(t *testing.T) { func TestHandler_Token(t *testing.T) { asserts := assert.New(t) - handler := Driver{} + handler := Driver{ + Policy: &model.Policy{}, + } ctx := context.Background() - _, err := handler.Token(ctx, 10, "123") + upSession := &serializer.UploadSession{SavePath: "TestHandler_Token"} + _, err := handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) asserts.NoError(err) + + file, _ := os.Create("TestHandler_Token") + defer func() { + file.Close() + os.Remove("TestHandler_Token") + }() + + _, err = handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) + asserts.Error(err) + asserts.Contains(err.Error(), "already exist") +} + +func TestDriver_CancelToken(t *testing.T) { + a := assert.New(t) + handler := Driver{} + a.NoError(handler.CancelToken(context.Background(), &serializer.UploadSession{})) } func TestDriver_List(t *testing.T) { diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 2723626..5dc2ec9 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -221,16 +221,8 @@ func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL stri return &uploadSession, nil } -var index = 0 - // UploadChunk 上传分片 func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) { - index++ - if index == 1 || index == 2 { - request.BlackHole(content) - return nil, errors.New("error") - } - res, err := client.request( ctx, "PUT", uploadURL, content, request.WithContentLength(current.Length()), @@ -331,16 +323,6 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read request.WithTimeout(time.Duration(150)*time.Second), ) if err != nil { - retried := 0 - if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok { - retried = v - } - if retried < model.GetIntSetting("chunk_retries", 5) { - retried++ - util.Log().Debug("文件[%s]上传失败[%s],5秒钟后重试", dst, err) - time.Sleep(time.Duration(5) * time.Second) - return client.SimpleUpload(context.WithValue(ctx, fsctx.RetryCtx, retried), dst, body, size, opts...) - } return nil, err } diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go index 8acc6db..fb6393d 100644 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -4,6 +4,11 @@ import ( "context" "errors" "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "io" "io/ioutil" "net/http" "strings" @@ -12,7 +17,6 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" @@ -307,6 +311,31 @@ func TestClient_Meta(t *testing.T) { asserts.NotNil(res) asserts.Equal("123321", res.Name) } + + // 返回正常, 使用资源id + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.Meta(context.Background(), "123321", "123") + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("123321", res.Name) + } } func TestClient_CreateUploadSession(t *testing.T) { @@ -442,9 +471,11 @@ func TestClient_UploadChunk(t *testing.T) { client, _ := NewClient(&model.Policy{}) client.Credential.AccessToken = "AccessToken" client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + cg := chunk.NewChunkGroup(&fsctx.FileStream{Size: 15}, 10, &backoff.ConstantBackoff{}, false) // 非最后分片,正常 { + cg.Next() client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() clientMock := ClientMock{} clientMock.On( @@ -453,6 +484,10 @@ func TestClient_UploadChunk(t *testing.T) { "http://dev.com", testMock.Anything, testMock.Anything, + testMock.Anything, + testMock.Anything, + testMock.Anything, + testMock.Anything, ).Return(&request.Response{ Err: nil, Response: &http.Response{ @@ -461,13 +496,7 @@ func TestClient_UploadChunk(t *testing.T) { }, }) client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ - Offset: 0, - ChunkSize: 10, - Total: 100, - Retried: 0, - Data: []byte("12313121231312"), - }) + res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) clientMock.AssertExpectations(t) asserts.NoError(err) asserts.Equal("http://dev.com/2", res.UploadURL) @@ -491,13 +520,7 @@ func TestClient_UploadChunk(t *testing.T) { }, }) client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ - Offset: 0, - ChunkSize: 10, - Total: 100, - Retried: 0, - Data: []byte("12313112313122"), - }) + res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) clientMock.AssertExpectations(t) asserts.Error(err) asserts.Nil(res) @@ -505,6 +528,7 @@ func TestClient_UploadChunk(t *testing.T) { // 最后分片,正常 { + cg.Next() client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() clientMock := ClientMock{} clientMock.On( @@ -521,19 +545,13 @@ func TestClient_UploadChunk(t *testing.T) { }, }) client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ - Offset: 95, - ChunkSize: 5, - Total: 100, - Retried: 0, - Data: []byte("1231312"), - }) + res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) clientMock.AssertExpectations(t) asserts.NoError(err) asserts.Nil(res) } - // 最后分片,第一次失败,重试后成功 + // 最后分片,失败 { cache.Set("setting_chunk_retries", "1", 0) client.Credential.ExpiresIn = 0 @@ -542,32 +560,11 @@ func TestClient_UploadChunk(t *testing.T) { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() }() clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) client.Request = clientMock - chunk := &Chunk{ - Offset: 95, - ChunkSize: 5, - Total: 100, - Retried: 0, - Data: []byte("1231312"), - } - res, err := client.UploadChunk(context.Background(), "http://dev.com", chunk) + res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) clientMock.AssertExpectations(t) - asserts.NoError(err) + asserts.Error(err) asserts.Nil(res) - asserts.EqualValues(1, chunk.Retried) } } @@ -576,16 +573,21 @@ func TestClient_Upload(t *testing.T) { client, _ := NewClient(&model.Policy{}) client.Credential.AccessToken = "AccessToken" client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) + ctx := context.Background() + cache.Set("setting_chunk_retries", "1", 0) + cache.Set("setting_use_temp_chunk_buffer", "false", 0) // 小文件,简单上传,失败 { client.Credential.ExpiresIn = 0 - err := client.Upload(ctx, "123.jpg", 3, strings.NewReader("123")) + err := client.Upload(ctx, &fsctx.FileStream{ + Size: 5, + File: io.NopCloser(strings.NewReader("12345")), + }) asserts.Error(err) } - // 上下文取消 + // 无法创建分片会话 { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() clientMock := ClientMock{} @@ -598,20 +600,20 @@ func TestClient_Upload(t *testing.T) { ).Return(&request.Response{ Err: nil, Response: &http.Response{ - StatusCode: 200, + StatusCode: 400, Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), }, }) client.Request = clientMock - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err := client.Upload(ctx, "123.jpg", 15*1024*1024, strings.NewReader("123")) + err := client.Upload(context.Background(), &fsctx.FileStream{ + Size: SmallFileSize + 1, + File: io.NopCloser(strings.NewReader("12345")), + }) clientMock.AssertExpectations(t) asserts.Error(err) - asserts.Equal(ErrClientCanceled, err) } - // 无法创建分片会话 + // 分片上传失败 { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() clientMock := ClientMock{} @@ -621,6 +623,19 @@ func TestClient_Upload(t *testing.T) { testMock.Anything, testMock.Anything, testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + clientMock.On( + "Request", + "PUT", + testMock.Anything, + testMock.Anything, + testMock.Anything, ).Return(&request.Response{ Err: nil, Response: &http.Response{ @@ -629,9 +644,13 @@ func TestClient_Upload(t *testing.T) { }, }) client.Request = clientMock - err := client.Upload(context.Background(), "123.jpg", 15*1024*1024, strings.NewReader("123")) + err := client.Upload(context.Background(), &fsctx.FileStream{ + Size: SmallFileSize + 1, + File: io.NopCloser(strings.NewReader("12345")), + }) clientMock.AssertExpectations(t) asserts.Error(err) + asserts.Contains(err.Error(), "failed to upload chunk") } } @@ -643,7 +662,7 @@ func TestClient_SimpleUpload(t *testing.T) { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() cache.Set("setting_chunk_retries", "1", 0) - // 请求失败,并重试 + // 请求失败 { client.Credential.ExpiresIn = 0 res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) @@ -651,7 +670,6 @@ func TestClient_SimpleUpload(t *testing.T) { asserts.Nil(res) } - cache.Set("setting_chunk_retries", "0", 0) // 返回未知响应 { client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() @@ -988,7 +1006,7 @@ func TestClient_MonitorUpload(t *testing.T) { asserts.NotPanics(func() { go func() { time.Sleep(time.Duration(1) * time.Second) - FinishCallback("key") + mq.GlobalMQ.Publish("key", mq.Message{}) }() client.MonitorUpload("url", "key", "path", 10, 10) }) diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go index c2b00ae..7700e7a 100644 --- a/pkg/filesystem/driver/onedrive/handler_test.go +++ b/pkg/filesystem/driver/onedrive/handler_test.go @@ -4,6 +4,9 @@ import ( "context" "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/jinzhu/gorm" "io" "io/ioutil" "net/http" @@ -12,51 +15,23 @@ import ( "testing" "time" - "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" ) func TestDriver_Token(t *testing.T) { asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - - // 无法获取文件路径 - { - ctx := context.WithValue(context.Background(), fsctx.FileSizeCtx, uint64(10)) - res, err := handler.Token(ctx, 10, "key", nil) - asserts.Error(err) - asserts.Equal(serializer.UploadCredential{}, res) - } - - // 无法获取文件大小 - { - ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") - res, err := handler.Token(ctx, 10, "key", nil) - asserts.Error(err) - asserts.Equal(serializer.UploadCredential{}, res) - } - - // 小文件成功 - { - ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") - ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(10)) - res, err := handler.Token(ctx, 10, "key", nil) - asserts.NoError(err) - asserts.Equal(serializer.UploadCredential{}, res) - } + h, _ := NewDriver(&model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }) + handler := h.(Driver) // 分片上传 失败 { @@ -78,11 +53,9 @@ func TestDriver_Token(t *testing.T) { }, }) handler.Client.Request = clientMock - ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") - ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024)) - res, err := handler.Token(ctx, 10, "key", nil) + res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) asserts.Error(err) - asserts.Equal(serializer.UploadCredential{}, res) + asserts.Nil(res) } // 分片上传 成功 @@ -108,15 +81,13 @@ func TestDriver_Token(t *testing.T) { }, }) handler.Client.Request = clientMock - ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") - ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024)) go func() { time.Sleep(time.Duration(1) * time.Second) - FinishCallback("key") + mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{}) }() - res, err := handler.Token(ctx, 10, "key", nil) + res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{}) asserts.NoError(err) - asserts.Equal("123321", res.Policy) + asserts.Equal("123321", res.UploadURLs[0]) } } @@ -295,12 +266,8 @@ func TestDriver_Thumb(t *testing.T) { // 失败 { ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20}) - ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{}) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() + ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}}) res, err := handler.Thumb(ctx, "123.jpg") - asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) asserts.Empty(res.URL) } @@ -308,7 +275,6 @@ func TestDriver_Thumb(t *testing.T) { // 上下文错误 { _, err := handler.Thumb(context.Background(), "123.jpg") - asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) } } @@ -329,7 +295,6 @@ func TestDriver_Delete(t *testing.T) { // 失败 { _, err := handler.Delete(context.Background(), []string{"1"}) - asserts.NoError(mock.ExpectationsWereMet()) asserts.Error(err) } @@ -350,7 +315,7 @@ func TestDriver_Put(t *testing.T) { // 失败 { - err := handler.Put(context.Background(), ioutil.NopCloser(strings.NewReader("")), "dst", 0) + err := handler.Put(context.Background(), &fsctx.FileStream{}) asserts.Error(err) } } @@ -418,3 +383,55 @@ func TestDriver_Get(t *testing.T) { asserts.NoError(err) asserts.Equal("123", string(content)) } + +func TestDriver_replaceSourceHost(t *testing.T) { + tests := []struct { + name string + origin string + cdn string + want string + wantErr bool + }{ + {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false}, + {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false}, + {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true}, + {"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy := &model.Policy{} + policy.OptionsSerialized.OdProxy = tt.cdn + handler := Driver{ + Policy: policy, + } + got, err := handler.replaceSourceHost(tt.origin) + if (err != nil) != tt.wantErr { + t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDriver_CancelToken(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 失败 + { + err := handler.CancelToken(context.Background(), &serializer.UploadSession{}) + asserts.Error(err) + } +} diff --git a/pkg/filesystem/driver/onedrive/handller_test.go b/pkg/filesystem/driver/onedrive/handller_test.go deleted file mode 100644 index 147a547..0000000 --- a/pkg/filesystem/driver/onedrive/handller_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package onedrive - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "testing" -) - -func TestDriver_replaceSourceHost(t *testing.T) { - tests := []struct { - name string - origin string - cdn string - want string - wantErr bool - }{ - {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false}, - {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false}, - {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true}, - {"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - policy := &model.Policy{} - policy.OptionsSerialized.OdProxy = tt.cdn - handler := Driver{ - Policy: policy, - } - got, err := handler.replaceSourceHost(tt.origin) - if (err != nil) != tt.wantErr { - t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filesystem/driver/onedrive/types.go index aefa638..2a4307f 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filesystem/driver/onedrive/types.go @@ -98,15 +98,6 @@ type ListResponse struct { Context string `json:"@odata.context"` } -// Chunk 文件分片 -type Chunk struct { - Offset int - ChunkSize int - Total int - Retried int - Data []byte -} - // oauthEndpoint OAuth接口地址 type oauthEndpoint struct { token url.URL @@ -142,8 +133,3 @@ type Site struct { func init() { gob.Register(Credential{}) } - -// IsLast 返回是否为最后一个分片 -func (chunk *Chunk) IsLast() bool { - return chunk.Total-chunk.Offset == chunk.ChunkSize -}