From 0d2e3cc717162cd2ac77fede8192b87ab272bd83 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 22 Jan 2020 10:50:47 +0800 Subject: [PATCH] Test: Onedrive.OAuth --- pkg/filesystem/driver/onedrive/api.go | 48 ++- pkg/filesystem/driver/onedrive/client.go | 7 +- pkg/filesystem/driver/onedrive/client_test.go | 7 + pkg/filesystem/driver/onedrive/oauth_test.go | 365 ++++++++++++++++++ pkg/filesystem/driver/oss/callback_test.go | 5 - pkg/filesystem/upload.go | 1 - 6 files changed, 405 insertions(+), 28 deletions(-) create mode 100644 pkg/filesystem/driver/onedrive/client_test.go create mode 100644 pkg/filesystem/driver/onedrive/oauth_test.go diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index d3cfba9..09f4eed 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -197,27 +197,33 @@ func (client *Client) Upload(ctx context.Context, dst string, size int, file io. chunkNum++ } for i := 0; i < chunkNum; i++ { - // 分块 - // TODO 取消上传 - chunkSize := int(ChunkSize) - if size-offset < chunkSize { - chunkSize = size - offset - } - chunk := Chunk{ - Offset: offset, - ChunkSize: chunkSize, - Total: size, - Reader: &io.LimitedReader{ - R: file, - N: int64(chunkSize), - }, - } - // 上传 - _, err := client.UploadChunk(ctx, uploadURL, &chunk) - if err != nil { - return err + select { + case <-ctx.Done(): + util.Log().Debug("OneDrive 客户端取消") + return ErrClientCanceled + default: + // 分块 + chunkSize := int(ChunkSize) + if size-offset < chunkSize { + chunkSize = size - offset + } + chunk := Chunk{ + Offset: offset, + ChunkSize: chunkSize, + Total: size, + Reader: &io.LimitedReader{ + R: file, + N: int64(chunkSize), + }, + } + // 上传 + _, err := client.UploadChunk(ctx, uploadURL, &chunk) + if err != nil { + return err + } + offset += chunkSize } - offset += chunkSize + } return nil } @@ -274,7 +280,7 @@ func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error // 取得删除失败的文件 failed := getDeleteFailed(&deleteRes) if len(failed) != 0 { - return failed, errors.New("无法删除文件") + return failed, ErrDeleteFile } return failed, nil } diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index a5ca4ab..ccf1512 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -8,8 +8,13 @@ import ( var ( // ErrAuthEndpoint 无法解析授权端点地址 - ErrAuthEndpoint = errors.New("无法解析授权端点地址") + ErrAuthEndpoint = errors.New("无法解析授权端点地址") + // ErrInvalidRefreshToken 上传策略无有效的RefreshToken ErrInvalidRefreshToken = errors.New("上传策略无有效的RefreshToken") + // ErrDeleteFile 无法删除文件 + ErrDeleteFile = errors.New("无法删除文件") + // ErrClientCanceled 客户端取消操作 + ErrClientCanceled = errors.New("客户端取消操作") ) // Client OneDrive客户端 diff --git a/pkg/filesystem/driver/onedrive/client_test.go b/pkg/filesystem/driver/onedrive/client_test.go new file mode 100644 index 0000000..01ee4db --- /dev/null +++ b/pkg/filesystem/driver/onedrive/client_test.go @@ -0,0 +1,7 @@ +package onedrive + +import "testing" + +func TestNewClient(t *testing.T) { + +} diff --git a/pkg/filesystem/driver/onedrive/oauth_test.go b/pkg/filesystem/driver/onedrive/oauth_test.go new file mode 100644 index 0000000..33857d8 --- /dev/null +++ b/pkg/filesystem/driver/onedrive/oauth_test.go @@ -0,0 +1,365 @@ +package onedrive + +import ( + "context" + "database/sql" + "errors" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/cache" + "github.com/HFO4/cloudreve/pkg/request" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + "time" +) + +var mock sqlmock.Sqlmock + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() +} + +func TestGetOAuthEndpoint(t *testing.T) { + asserts := assert.New(t) + + // URL解析失败 + { + client := Client{ + Endpoints: &Endpoints{ + OAuthURL: string([]byte{0x7f}), + }, + } + res := client.getOAuthEndpoint() + asserts.Nil(res) + } + + { + testCase := []struct { + OAuthURL string + token string + auth string + isChina bool + }{ + { + OAuthURL: "http://login.live.com", + token: "https://login.live.com/oauth20_token.srf", + auth: "https://login.live.com/oauth20_authorize.srf", + isChina: false, + }, + { + OAuthURL: "http://login.chinacloudapi.cn", + token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token", + auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize", + isChina: true, + }, + { + OAuthURL: "other", + token: "https://login.microsoftonline.com/common/oauth2/v2.0/token", + auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + isChina: false, + }, + } + + for i, testCase := range testCase { + client := Client{ + Endpoints: &Endpoints{ + OAuthURL: testCase.OAuthURL, + }, + } + res := client.getOAuthEndpoint() + asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i) + asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i) + asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i) + } + } +} + +func TestClient_OAuthURL(t *testing.T) { + asserts := assert.New(t) + + client := Client{ + ClientID: "client_id", + Redirect: "http://cloudreve.org/callback", + Endpoints: &Endpoints{}, + } + client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() + res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"})) + asserts.NoError(err) + query := res.Query() + asserts.Equal("client_id", query.Get("client_id")) + asserts.Equal("scope1 scope2", query.Get("scope")) + asserts.Equal(client.Redirect, query.Get("redirect_uri")) + +} + +type ClientMock struct { + testMock.Mock +} + +func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { + args := m.Called(method, target, body, opts) + return args.Get(0).(*request.Response) +} + +type mockReader string + +func (r mockReader) Read(b []byte) (int, error) { + return 0, errors.New("read error") +} + +func TestClient_ObtainToken(t *testing.T) { + asserts := assert.New(t) + + client := Client{ + Endpoints: &Endpoints{}, + ClientID: "ClientID", + ClientSecret: "ClientSecret", + Redirect: "Redirect", + } + client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() + + // 刷新Token 成功 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)), + }, + }) + client.Request = clientMock + + res, err := client.ObtainToken(context.Background()) + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotNil(res) + asserts.Equal("i am token", res.AccessToken) + } + + // 重新获取 无法发送请求 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: errors.New("error"), + }) + client.Request = clientMock + + res, err := client.ObtainToken(context.Background(), WithCode("code")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } + + // 刷新Token 无法获取响应正文 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(mockReader("")), + }, + }) + client.Request = clientMock + + res, err := client.ObtainToken(context.Background()) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + asserts.Equal("read error", err.Error()) + } + + // 刷新Token OneDrive返回错误 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)), + }, + }) + client.Request = clientMock + + res, err := client.ObtainToken(context.Background()) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + asserts.Equal("", err.Error()) + } + + // 刷新Token OneDrive未知响应 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + + res, err := client.ObtainToken(context.Background()) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Nil(res) + } +} + +func TestClient_UpdateCredential(t *testing.T) { + asserts := assert.New(t) + client := Client{ + Policy: &model.Policy{Model: gorm.Model{ID: 257}}, + Endpoints: &Endpoints{}, + ClientID: "TestClient_UpdateCredential", + ClientSecret: "ClientSecret", + Redirect: "Redirect", + Credential: &Credential{}, + } + client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() + + // 无有效的RefreshToken + { + err := client.UpdateCredential(context.Background()) + asserts.Equal(ErrInvalidRefreshToken, err) + client.Credential = nil + err = client.UpdateCredential(context.Background()) + asserts.Equal(ErrInvalidRefreshToken, err) + } + + // 成功 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)), + }, + }) + client.Request = clientMock + client.Credential = &Credential{ + RefreshToken: "old_refresh_token", + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err := client.UpdateCredential(context.Background()) + clientMock.AssertExpectations(t) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential") + asserts.True(ok) + cacheCredential := cacheRes.(Credential) + asserts.Equal("new_refresh_token", cacheCredential.RefreshToken) + asserts.Equal("i am token", cacheCredential.AccessToken) + } + + // OneDrive返回错误 + { + cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_") + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + client.Endpoints.OAuthEndpoints.token.String(), + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 400, + Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)), + }, + }) + client.Request = clientMock + client.Credential = &Credential{ + RefreshToken: "old_refresh_token", + } + err := client.UpdateCredential(context.Background()) + clientMock.AssertExpectations(t) + asserts.Error(err) + } + + // 从缓存中获取 + { + cache.Set("onedrive_TestClient_UpdateCredential", Credential{ + ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + }, 0) + client.Credential = &Credential{ + RefreshToken: "old_refresh_token", + } + err := client.UpdateCredential(context.Background()) + asserts.NoError(err) + asserts.Equal("AccessToken", client.Credential.AccessToken) + asserts.Equal("RefreshToken", client.Credential.RefreshToken) + } + + // 无需重新获取 + { + client.Credential = &Credential{ + RefreshToken: "old_refresh_token", + AccessToken: "AccessToken2", + ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), + } + err := client.UpdateCredential(context.Background()) + asserts.NoError(err) + asserts.Equal("AccessToken2", client.Credential.AccessToken) + } +} diff --git a/pkg/filesystem/driver/oss/callback_test.go b/pkg/filesystem/driver/oss/callback_test.go index ca82666..20c8a2c 100644 --- a/pkg/filesystem/driver/oss/callback_test.go +++ b/pkg/filesystem/driver/oss/callback_test.go @@ -185,8 +185,3 @@ C0fTXv+nvlmklvkcolvpvXLTjaxUHR3W9LXxQ2EHXAJfCB+6H2YF1k8CAwEAAQ== asserts.Error(VerifyCallbackSignature(&r)) } } - -///api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH -//{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"} -// aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0= -// e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw== diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index f1be15e..9733987 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -116,7 +116,6 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe } else { reqContext = ctx.Value(fsctx.HTTPCtx).(context.Context) } - defer fs.Recycle() select { case <-reqContext.Done():