diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 09f4eed..9ce6fed 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -393,7 +393,7 @@ func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size ui cache.Deletes([]string{callbackKey}, "callback_") _, err = client.Delete(context.Background(), []string{path}) if err != nil { - util.Log().Warning("无法删除未回掉的文件,%s", err) + util.Log().Warning("无法删除未回调的文件,%s", err) } } return @@ -486,7 +486,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo decodeErr error ) // 如果有错误 - if res.Response.StatusCode < 200 && res.Response.StatusCode >= 300 { + if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 { decodeErr = json.Unmarshal([]byte(respBody), &errResp) if decodeErr != nil { return "", sysError(decodeErr) diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go new file mode 100644 index 0000000..e3511ac --- /dev/null +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -0,0 +1,1012 @@ +package onedrive + +import ( + "context" + "errors" + "fmt" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/request" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" +) + +func TestRequest(t *testing.T) { + asserts := assert.New(t) + client := Client{ + Policy: &model.Policy{}, + ClientID: "TestRequest", + Credential: &Credential{ + ExpiresIn: time.Now().Add(time.Duration(100) * time.Hour).Unix(), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + }, + } + + // 请求发送失败 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: errors.New("error"), + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + asserts.Equal("error", err.Error()) + } + + // 无法更新凭证 + { + client.Credential.RefreshToken = "" + client.Credential.AccessToken = "" + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + asserts.Error(err) + asserts.Empty(res) + client.Credential.RefreshToken = "RefreshToken" + client.Credential.AccessToken = "AccessToken" + } + + // 无法获取响应正文 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(mockReader("")), + }, + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + } + + // OneDrive返回错误 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), + }, + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + asserts.Equal("error msg", err.Error()) + } + + // OneDrive返回未知响应 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + } +} + +func TestFileInfo_GetSourcePath(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + fileInfo := FileInfo{ + Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg", + ParentReference: parentReference{ + Path: "/drive/root:/123/321", + }, + } + asserts.Equal("123/321/文件名.jpg", fileInfo.GetSourcePath()) + } + + // 失败 + { + fileInfo := FileInfo{ + Name: "%e6%96%87%e4%bb%b6%e5%90%8g.jpg", + ParentReference: parentReference{ + Path: "/drive/root:/123/321", + }, + } + asserts.Equal("", fileInfo.GetSourcePath()) + } +} + +func TestClient_GetRequestURL(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + + // 出错 + { + client.Endpoints.EndpointURL = string([]byte{0x7f}) + asserts.Equal("", client.getRequestURL("123")) + } +} + +func TestClient_Meta(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.Meta(context.Background(), "", "123") + asserts.Error(err) + asserts.Nil(res) + + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.Meta(context.Background(), "", "123") + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.Meta(context.Background(), "", "123") + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("123321", res.Name) + } +} + +func TestClient_CreateUploadSession(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.CreateUploadSession(context.Background(), "123.jpg") + asserts.Error(err) + asserts.Empty(res) + + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.CreateUploadSession(context.Background(), "123.jpg") + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.CreateUploadSession(context.Background(), "123.jpg", WithConflictBehavior("fail")) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("123321", res) + } +} + +func TestClient_GetUploadSessionStatus(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") + asserts.Error(err) + asserts.Empty(res) + + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("123321", res.UploadURL) + } +} + +func TestClient_UploadChunk(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 非最后分片,正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)), + }, + }) + client.Request = clientMock + res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ + Offset: 0, + ChunkSize: 10, + Total: 100, + Retried: 0, + Reader: strings.NewReader("1231312"), + }) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.Equal("http://dev.com/2", res.UploadURL) + } + + // 非最后分片,异常响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ + Offset: 0, + ChunkSize: 10, + Total: 100, + Retried: 0, + Reader: strings.NewReader("1231312"), + }) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } + + // 最后分片,正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.UploadChunk(context.Background(), "http://dev.com", &Chunk{ + Offset: 95, + ChunkSize: 5, + Total: 100, + Retried: 0, + Reader: strings.NewReader("1231312"), + }) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.Nil(res) + } + + // 最后分片,第一次失败,重试后成功 + { + cache.Set("setting_onedrive_chunk_retries", "1", 0) + client.Credential.ExpiresIn = 0 + go func() { + time.Sleep(time.Duration(2) * time.Second) + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + }() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + chunk := &Chunk{ + Offset: 95, + ChunkSize: 5, + Total: 100, + Retried: 0, + Reader: strings.NewReader("1231312"), + } + res, err := client.UploadChunk(context.Background(), "http://dev.com", chunk) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.Nil(res) + asserts.EqualValues(1, chunk.Retried) + } +} + +func TestClient_Upload(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 小文件,简单上传,失败 + { + client.Credential.ExpiresIn = 0 + err := client.Upload(context.Background(), "123.jpg", 3, strings.NewReader("123")) + asserts.Error(err) + } + + // 大文件 分两个分片 成功 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + clientMock.On( + "Request", + "PUT", + "123321", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)), + }, + }) + client.Request = clientMock + + err := client.Upload(context.Background(), "123.jpg", 15*1024*1024, strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.NoError(err) + } + + // 大文件 分两个分片 失败 + { + cache.Set("setting_onedrive_chunk_retries", "0", 0) + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + clientMock.On( + "Request", + "PUT", + "123321", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)), + }, + }) + client.Request = clientMock + + err := client.Upload(context.Background(), "123.jpg", 15*1024*1024, strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.Error(err) + } + + // 上下文取消 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + client.Request = clientMock + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := client.Upload(ctx, "123.jpg", 15*1024*1024, strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Equal(ErrClientCanceled, err) + } + + // 无法创建分片会话 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + client.Request = clientMock + err := client.Upload(context.Background(), "123.jpg", 15*1024*1024, strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.Error(err) + } + +} + +func TestClient_SimpleUpload(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + asserts.Error(err) + asserts.Nil(res) + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "PUT", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123")) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("123321", res.Name) + } +} + +func TestClient_DeleteUploadSession(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + err := client.DeleteUploadSession(context.Background(), "123.jpg") + asserts.Error(err) + + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "DELETE", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 204, + Body: ioutil.NopCloser(strings.NewReader(``)), + }, + }) + client.Request = clientMock + err := client.DeleteUploadSession(context.Background(), "123.jpg") + clientMock.AssertExpectations(t) + asserts.NoError(err) + } +} + +func TestClient_Delete(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) + asserts.Error(err) + asserts.Len(res, 3) + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Len(res, 3) + } + + // 成功2两个文件 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), + }, + }) + client.Request = clientMock + res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Equal([]string{"2"}, res) + } +} + +func TestClient_GetThumbURL(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) + asserts.Error(err) + asserts.Empty(res) + } + + // 未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) + asserts.Error(err) + asserts.Empty(res) + } + + // 世纪互联 成功 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + client.Endpoints.isInChina = true + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"url":"thumb"}`)), + }, + }) + client.Request = clientMock + res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) + asserts.NoError(err) + asserts.Equal("thumb", res) + } + + // 非世纪互联 成功 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + client.Endpoints.isInChina = false + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"c1x1_Crop":{"url":"thumb"}}]}`)), + }, + }) + client.Request = clientMock + res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) + asserts.NoError(err) + asserts.Equal("thumb", res) + } +} + +func TestClient_MonitorUpload(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 客户端完成回调 + { + cache.Set("setting_onedrive_monitor_timeout", "600", 0) + cache.Set("setting_onedrive_callback_check", "20", 0) + asserts.NotPanics(func() { + go func() { + time.Sleep(time.Duration(1) * time.Second) + FinishCallback("key") + }() + client.MonitorUpload("url", "key", "path", 10, 10) + }) + } + + // 上传会话到期,仍未完成上传,创建占位符 + { + cache.Set("setting_onedrive_monitor_timeout", "600", 0) + cache.Set("setting_onedrive_callback_check", "20", 0) + asserts.NotPanics(func() { + client.MonitorUpload("url", "key", "path", 10, 0) + }) + } + + fmt.Println("测试:上传已完成,未发送回调") + // 上传已完成,未发送回调 + { + cache.Set("setting_onedrive_monitor_timeout", "0", 0) + cache.Set("setting_onedrive_callback_check", "0", 0) + + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + client.Credential.AccessToken = "1" + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 404, + Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), + }, + }) + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 404, + Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), + }, + }) + client.Request = clientMock + cache.Set("callback_key3", "ok", 0) + + asserts.NotPanics(func() { + client.MonitorUpload("url", "key3", "path", 10, 10) + }) + + clientMock.AssertExpectations(t) + } + + fmt.Println("测试:上传仍未开始") + // 上传仍未开始 + { + cache.Set("setting_onedrive_monitor_timeout", "0", 0) + cache.Set("setting_onedrive_callback_check", "0", 0) + + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + client.Credential.AccessToken = "1" + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"nextExpectedRanges":["0-"]}`)), + }, + }) + clientMock.On( + "Request", + "DELETE", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(``)), + }, + }) + clientMock.On( + "Request", + "PUT", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{}`)), + }, + }) + client.Request = clientMock + + asserts.NotPanics(func() { + client.MonitorUpload("url", "key4", "path", 10, 10) + }) + + clientMock.AssertExpectations(t) + } + +} diff --git a/pkg/filesystem/driver/onedrive/client_test.go b/pkg/filesystem/driver/onedrive/client_test.go index 01ee4db..c298a82 100644 --- a/pkg/filesystem/driver/onedrive/client_test.go +++ b/pkg/filesystem/driver/onedrive/client_test.go @@ -1,7 +1,31 @@ package onedrive -import "testing" +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/stretchr/testify/assert" + "testing" +) func TestNewClient(t *testing.T) { + asserts := assert.New(t) + // getOAuthEndpoint失败 + { + policy := model.Policy{ + BaseURL: string([]byte{0x7f}), + } + res, err := NewClient(&policy) + asserts.Error(err) + asserts.Nil(res) + } + // 成功 + { + policy := model.Policy{} + res, err := NewClient(&policy) + asserts.NoError(err) + asserts.NotNil(res) + asserts.NotNil(res.Credential) + asserts.NotNil(res.Endpoints) + asserts.NotNil(res.Endpoints.OAuthEndpoints) + } } diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go new file mode 100644 index 0000000..6040fb6 --- /dev/null +++ b/pkg/filesystem/driver/onedrive/handler_test.go @@ -0,0 +1,303 @@ +package onedrive + +import ( + "context" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/request" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + "time" +) + +func TestDriver_Token(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + + // 无法获取文件路径 + { + ctx := context.WithValue(context.Background(), fsctx.FileSizeCtx, uint64(10)) + res, err := handler.Token(ctx, 10, "key") + asserts.Error(err) + asserts.Equal(serializer.UploadCredential{}, res) + } + + // 无法获取文件大小 + { + ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") + res, err := handler.Token(ctx, 10, "key") + asserts.Error(err) + asserts.Equal(serializer.UploadCredential{}, res) + } + + // 小文件成功 + { + ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") + ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(10)) + res, err := handler.Token(ctx, 10, "key") + asserts.NoError(err) + asserts.Equal(serializer.UploadCredential{}, res) + } + + // 分片上传 失败 + { + cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + handler.Client.Request = clientMock + ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") + ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024)) + res, err := handler.Token(ctx, 10, "key") + asserts.Error(err) + asserts.Equal(serializer.UploadCredential{}, res) + } + + // 分片上传 成功 + { + cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) + cache.Set("setting_onedrive_monitor_timeout", "600", 0) + cache.Set("setting_onedrive_callback_check", "20", 0) + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + handler.Client.Credential.AccessToken = "1" + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), + }, + }) + handler.Client.Request = clientMock + ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123") + ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024)) + go func() { + time.Sleep(time.Duration(1) * time.Second) + FinishCallback("key") + }() + res, err := handler.Token(ctx, 10, "key") + asserts.NoError(err) + asserts.Equal("123321", res.Policy) + } +} + +func TestDriver_Source(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 失败 + { + res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0) + asserts.Error(err) + asserts.Empty(res) + } + + // 成功 + { + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), + }, + }) + handler.Client.Request = clientMock + handler.Client.Credential.AccessToken = "1" + res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0) + asserts.NoError(err) + asserts.Equal("123321", res) + } +} + +func TestDriver_Thumb(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 失败 + { + ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20}) + ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{}) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + res, err := handler.Thumb(ctx, "123.jpg") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Empty(res.URL) + } + + // 上下文错误 + { + _, err := handler.Thumb(context.Background(), "123.jpg") + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + } +} + +func TestDriver_Delete(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 失败 + { + _, err := handler.Delete(context.Background(), []string{"1"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + } + +} + +func TestDriver_Put(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 失败 + { + err := handler.Put(context.Background(), ioutil.NopCloser(strings.NewReader("")), "dst", 0) + asserts.Error(err) + } +} + +func TestDriver_Get(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + handler.Client, _ = NewClient(&model.Policy{}) + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + + // 无法获取source + { + res, err := handler.Get(context.Background(), "123.txt") + asserts.Error(err) + asserts.Nil(res) + } + + // 成功 + handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), + }, + }) + handler.Client.Request = clientMock + handler.Client.Credential.AccessToken = "1" + + driverClientMock := ClientMock{} + driverClientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`123`)), + }, + }) + handler.HTTPClient = driverClientMock + res, err := handler.Get(context.Background(), "123.txt") + clientMock.AssertExpectations(t) + asserts.NoError(err) + _, err = res.Seek(0, io.SeekEnd) + asserts.NoError(err) + content, err := ioutil.ReadAll(res) + asserts.NoError(err) + asserts.Equal("123", string(content)) +} diff --git a/pkg/filesystem/driver/onedrive/handller.go b/pkg/filesystem/driver/onedrive/handller.go index f0c0516..b04dce3 100644 --- a/pkg/filesystem/driver/onedrive/handller.go +++ b/pkg/filesystem/driver/onedrive/handller.go @@ -26,7 +26,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, ctx, path, url.URL{}, - int64(model.GetIntSetting("preview_timeout", 60)), + 60, false, 0, ) diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index aebaf6d..28dbd9c 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -2,7 +2,6 @@ package onedrive import ( "context" - "encoding/gob" "encoding/json" "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/request" @@ -14,33 +13,6 @@ import ( "time" ) -// oauthEndpoint OAuth接口地址 -type oauthEndpoint struct { - token url.URL - authorize url.URL -} - -// Credential 获取token时返回的凭证 -type Credential struct { - TokenType string `json:"token_type"` - ExpiresIn int64 `json:"expires_in"` - Scope string `json:"scope"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - UserID string `json:"user_id"` -} - -// OAuthError OAuth相关接口的错误响应 -type OAuthError struct { - ErrorType string `json:"error"` - ErrorDescription string `json:"error_description"` - CorrelationID string `json:"correlation_id"` -} - -func init() { - gob.Register(Credential{}) -} - // Error 实现error接口 func (err OAuthError) Error() string { return err.ErrorDescription diff --git a/pkg/filesystem/driver/onedrive/options.go b/pkg/filesystem/driver/onedrive/options.go index dd5d855..5d09ef7 100644 --- a/pkg/filesystem/driver/onedrive/options.go +++ b/pkg/filesystem/driver/onedrive/options.go @@ -38,13 +38,6 @@ func WithConflictBehavior(t string) Option { }) } -// WithExpires 设置过期时间 -func WithExpires(t time.Time) Option { - return optionFunc(func(o *options) { - o.expires = t - }) -} - func (f optionFunc) apply(o *options) { f(o) } diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filesystem/driver/onedrive/types.go index 5730ead..e3e644d 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filesystem/driver/onedrive/types.go @@ -1,7 +1,9 @@ package onedrive import ( + "encoding/gob" "io" + "net/url" "sync" ) @@ -91,6 +93,33 @@ type Chunk struct { Reader io.Reader } +// oauthEndpoint OAuth接口地址 +type oauthEndpoint struct { + token url.URL + authorize url.URL +} + +// Credential 获取token时返回的凭证 +type Credential struct { + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Scope string `json:"scope"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + UserID string `json:"user_id"` +} + +// OAuthError OAuth相关接口的错误响应 +type OAuthError struct { + ErrorType string `json:"error"` + ErrorDescription string `json:"error_description"` + CorrelationID string `json:"correlation_id"` +} + +func init() { + gob.Register(Credential{}) +} + // IsLast 返回是否为最后一个分片 func (chunk *Chunk) IsLast() bool { return chunk.Total-chunk.Offset == chunk.ChunkSize