From 37926e3133f06e9718a86a6f438fdfb1e4aa40ac Mon Sep 17 00:00:00 2001 From: Aaron Liu <912394456@qq.com> Date: Wed, 24 May 2023 14:39:54 +0800 Subject: [PATCH] feat(policy): add Google Drive Oauth client --- assets | 2 +- models/policy.go | 36 ++-- models/policy_test.go | 6 +- pkg/cluster/controller.go | 10 +- pkg/cluster/controller_test.go | 8 +- pkg/filesystem/driver/googledrive/client.go | 72 ++++++++ pkg/filesystem/driver/googledrive/handler.go | 70 +++++++ pkg/filesystem/driver/googledrive/oauth.go | 174 ++++++++++++++++++ pkg/filesystem/driver/onedrive/client.go | 2 +- pkg/filesystem/driver/onedrive/oauth.go | 11 +- pkg/filesystem/driver/onedrive/oauth_test.go | 4 +- pkg/filesystem/filesystem.go | 4 + .../onedrive/lock.go => oauth/mutex.go} | 2 +- pkg/filesystem/oauth/token.go | 8 + pkg/mocks/controllermock/c.go | 2 +- routers/controllers/admin.go | 18 +- routers/controllers/callback.go | 22 ++- routers/controllers/slave.go | 6 +- routers/router.go | 22 ++- service/admin/policy.go | 41 +++-- service/callback/oauth.go | 56 +++++- service/node/fabric.go | 27 ++- 22 files changed, 524 insertions(+), 79 deletions(-) create mode 100644 pkg/filesystem/driver/googledrive/client.go create mode 100644 pkg/filesystem/driver/googledrive/handler.go create mode 100644 pkg/filesystem/driver/googledrive/oauth.go rename pkg/filesystem/{driver/onedrive/lock.go => oauth/mutex.go} (96%) create mode 100644 pkg/filesystem/oauth/token.go diff --git a/assets b/assets index 88e4b7f..5e66c6f 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit 88e4b7fbf3d5e5806ad0cedce99d845ef21704c7 +Subproject commit 5e66c6fd9c8b50f31a7a2d1f1e77ef13543ee366 diff --git a/models/policy.go b/models/policy.go index bfb0453..896689f 100644 --- a/models/policy.go +++ b/models/policy.go @@ -47,8 +47,8 @@ type PolicyOption struct { FileType []string `json:"file_type"` // MimeType MimeType string `json:"mimetype"` - // OdRedirect Onedrive 重定向地址 - OdRedirect string `json:"od_redirect,omitempty"` + // OauthRedirect Oauth 重定向地址 + OauthRedirect string `json:"od_redirect,omitempty"` // OdProxy Onedrive 反代地址 OdProxy string `json:"od_proxy,omitempty"` // OdDriver OneDrive 驱动器定位符 @@ -155,23 +155,23 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string { fileRule := policy.FileNameRule replaceTable := map[string]string{ - "{randomkey16}": util.RandStringRunes(16), - "{randomkey8}": util.RandStringRunes(8), - "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), - "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), - "{uid}": strconv.Itoa(int(uid)), - "{datetime}": time.Now().Format("20060102150405"), - "{date}": time.Now().Format("20060102"), - "{year}": time.Now().Format("2006"), - "{month}": time.Now().Format("01"), - "{day}": time.Now().Format("02"), - "{hour}": time.Now().Format("15"), - "{minute}": time.Now().Format("04"), - "{second}": time.Now().Format("05"), - "{originname}": origin, - "{ext}": filepath.Ext(origin), + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), + "{uid}": strconv.Itoa(int(uid)), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + "{year}": time.Now().Format("2006"), + "{month}": time.Now().Format("01"), + "{day}": time.Now().Format("02"), + "{hour}": time.Now().Format("15"), + "{minute}": time.Now().Format("04"), + "{second}": time.Now().Format("05"), + "{originname}": origin, + "{ext}": filepath.Ext(origin), "{originname_without_ext}": strings.TrimSuffix(origin, filepath.Ext(origin)), - "{uuid}": uuid.Must(uuid.NewV4()).String(), + "{uuid}": uuid.Must(uuid.NewV4()).String(), } fileRule = util.Replace(replaceTable, fileRule) diff --git a/models/policy_test.go b/models/policy_test.go index 36022e1..f7d4e74 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -25,7 +25,7 @@ func TestGetPolicyByID(t *testing.T) { asserts.NoError(err) asserts.NoError(mock.ExpectationsWereMet()) asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OdRedirect) + asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) rows = sqlmock.NewRows([]string{"name", "type", "options"}) mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) @@ -39,7 +39,7 @@ func TestGetPolicyByID(t *testing.T) { policy, err := GetPolicyByID(uint(22)) asserts.NoError(err) asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OdRedirect) + asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) } @@ -50,7 +50,7 @@ func TestPolicy_BeforeSave(t *testing.T) { testPolicy := Policy{ OptionsSerialized: PolicyOption{ - OdRedirect: "123", + OauthRedirect: "123", }, } expected, _ := json.Marshal(testPolicy.OptionsSerialized) diff --git a/pkg/cluster/controller.go b/pkg/cluster/controller.go index c597d04..85fb178 100644 --- a/pkg/cluster/controller.go +++ b/pkg/cluster/controller.go @@ -35,8 +35,8 @@ type Controller interface { // Get master node info GetMasterInfo(string) (*MasterInfo, error) - // Get master OneDrive policy credential - GetOneDriveToken(string, uint) (string, error) + // Get master Oauth based policy credential + GetPolicyOauthToken(string, uint) (string, error) } type slaveController struct { @@ -181,8 +181,8 @@ func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { return nil, ErrMasterNotFound } -// GetOneDriveToken 获取主机OneDrive凭证 -func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) { +// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证 +func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) { c.lock.RLock() if node, ok := c.masters[id]; ok { @@ -190,7 +190,7 @@ func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, er res, err := node.Client.Request( "GET", - fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID), + fmt.Sprintf("/api/v3/slave/credential/%d", policyID), nil, ).CheckHTTPResponse(200).DecodeResponse() if err != nil { diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go index 305856a..22b25ea 100644 --- a/pkg/cluster/controller_test.go +++ b/pkg/cluster/controller_test.go @@ -320,7 +320,7 @@ func TestSlaveController_GetOneDriveToken(t *testing.T) { // node not exit { - res, err := c.GetOneDriveToken("2", 1) + res, err := c.GetPolicyOauthToken("2", 1) a.Equal(ErrMasterNotFound, err) a.Empty(res) } @@ -336,7 +336,7 @@ func TestSlaveController_GetOneDriveToken(t *testing.T) { "1": {Client: mockRequest}, }, } - res, err := c.GetOneDriveToken("1", 1) + res, err := c.GetPolicyOauthToken("1", 1) a.Error(err) a.Empty(res) mockRequest.AssertExpectations(t) @@ -356,7 +356,7 @@ func TestSlaveController_GetOneDriveToken(t *testing.T) { "1": {Client: mockRequest}, }, } - res, err := c.GetOneDriveToken("1", 1) + res, err := c.GetPolicyOauthToken("1", 1) a.Equal(1, err.(serializer.AppError).Code) a.Empty(res) mockRequest.AssertExpectations(t) @@ -376,7 +376,7 @@ func TestSlaveController_GetOneDriveToken(t *testing.T) { "1": {Client: mockRequest}, }, } - res, err := c.GetOneDriveToken("1", 1) + res, err := c.GetPolicyOauthToken("1", 1) a.NoError(err) a.Equal("expected", res) mockRequest.AssertExpectations(t) diff --git a/pkg/filesystem/driver/googledrive/client.go b/pkg/filesystem/driver/googledrive/client.go new file mode 100644 index 0000000..c4b0305 --- /dev/null +++ b/pkg/filesystem/driver/googledrive/client.go @@ -0,0 +1,72 @@ +package googledrive + +import ( + "errors" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/request" +) + +// Client Google Drive client +type Client struct { + Endpoints *Endpoints + Policy *model.Policy + Credential *Credential + + ClientID string + ClientSecret string + Redirect string + + Request request.Client + ClusterController cluster.Controller +} + +// Endpoints OneDrive客户端相关设置 +type Endpoints struct { + UserConsentEndpoint string // OAuth认证的基URL + TokenEndpoint string // OAuth token 基URL + EndpointURL string // 接口请求的基URL +} + +const ( + TokenCachePrefix = "googledrive_" + + oauthEndpoint = "https://oauth2.googleapis.com/token" + userConsentBase = "https://accounts.google.com/o/oauth2/auth" + v3DriveEndpoint = "https://www.googleapis.com/drive/v3" +) + +var ( + // Defualt required scopes + RequiredScope = []string{ + "https://www.googleapis.com/auth/drive", + "openid", + "profile", + "https://www.googleapis.com/auth/userinfo.profile", + } + + // ErrInvalidRefreshToken 上传策略无有效的RefreshToken + ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy") +) + +// NewClient 根据存储策略获取新的client +func NewClient(policy *model.Policy) (*Client, error) { + client := &Client{ + Endpoints: &Endpoints{ + TokenEndpoint: oauthEndpoint, + UserConsentEndpoint: userConsentBase, + EndpointURL: v3DriveEndpoint, + }, + Credential: &Credential{ + RefreshToken: policy.AccessKey, + }, + Policy: policy, + ClientID: policy.BucketName, + ClientSecret: policy.SecretKey, + Redirect: policy.OptionsSerialized.OauthRedirect, + Request: request.NewClient(), + ClusterController: cluster.DefaultController, + } + + return client, nil +} diff --git a/pkg/filesystem/driver/googledrive/handler.go b/pkg/filesystem/driver/googledrive/handler.go new file mode 100644 index 0000000..08a3855 --- /dev/null +++ b/pkg/filesystem/driver/googledrive/handler.go @@ -0,0 +1,70 @@ +package googledrive + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "net/url" +) + +// Driver Google Drive 适配器 +type Driver struct { + Policy *model.Policy + HTTPClient request.Client +} + +// http://localhost:3000/api/v3/callback/googledrive/auth?code=4/0AVHEtk4AaNbo5YoCrMSGgoJfZfe6SgEOVmA7XtalZl8BMtdsAIRWqxt6jO4NKJCxGVxyQA&scope=profile%20openid%20https://www.googleapis.com/auth/userinfo.profile%20https://www.googleapis.com/auth/drive&authuser=0&prompt=consent +// https://accounts.google.com/o/oauth2/v2/auth?client_id=89866991293-5uja7qsbl8btuak3hb41h3o8u9jhlckg.apps.googleusercontent.com&response_type=code&redirect_uri=http://localhost:3000/api/v3/callback/googledrive/auth&scope=openid+profile+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&access_type=offline&prompt=consent +//https://accounts.google.com/o/oauth2/auth?client_id=202264815644.apps.googleusercontent.com&response_type=code&redirect_uri=http%3A%2F%2F127.0.0.1%3A53682%2F&scope=openid+profile+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fphotoslibrary&access_type=offline&prompt=consent&state=MjAyMjY0ODE1NjQ0LmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29tOjpYNFozY2E4eGZXRGIxVm9vLUY5YTdaeEo6Omh0dHA6Ly8xMjcuMC4wLjE6NTM2ODIv + +// NewDriver 从存储策略初始化新的Driver实例 +func NewDriver(policy *model.Policy) (driver.Handler, error) { + return &Driver{ + Policy: policy, + HTTPClient: request.NewClient(), + }, nil +} + +func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { + //TODO implement me + panic("implement me") +} + +func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { + //TODO implement me + panic("implement me") +} + +func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { + //TODO implement me + panic("implement me") +} + +func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { + //TODO implement me + panic("implement me") +} + +func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { + //TODO implement me + panic("implement me") +} + +func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { + //TODO implement me + panic("implement me") +} + +func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { + //TODO implement me + panic("implement me") +} + +func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { + //TODO implement me + panic("implement me") +} diff --git a/pkg/filesystem/driver/googledrive/oauth.go b/pkg/filesystem/driver/googledrive/oauth.go new file mode 100644 index 0000000..afb3460 --- /dev/null +++ b/pkg/filesystem/driver/googledrive/oauth.go @@ -0,0 +1,174 @@ +package googledrive + +import ( + "context" + "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// Credential 获取token时返回的凭证 +type Credential struct { + ExpiresIn int64 `json:"expires_in"` + Scope string `json:"scope"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + UserID string `json:"user_id"` +} + +// OAuthError OAuth相关接口的错误响应 +type OAuthError struct { + ErrorType string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +// Error 实现error接口 +func (err OAuthError) Error() string { + return err.ErrorDescription +} + +// OAuthURL 获取OAuth认证页面URL +func (client *Client) OAuthURL(ctx context.Context, scope []string) string { + query := url.Values{ + "client_id": {client.ClientID}, + "scope": {strings.Join(scope, " ")}, + "response_type": {"code"}, + "redirect_uri": {client.Redirect}, + "access_type": {"offline"}, + "prompt": {"consent"}, + } + + u, _ := url.Parse(client.Endpoints.UserConsentEndpoint) + u.RawQuery = query.Encode() + return u.String() +} + +// ObtainToken 通过code或refresh_token兑换token +func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) { + body := url.Values{ + "client_id": {client.ClientID}, + "redirect_uri": {client.Redirect}, + "client_secret": {client.ClientSecret}, + } + if code != "" { + body.Add("grant_type", "authorization_code") + body.Add("code", code) + } else { + body.Add("grant_type", "refresh_token") + body.Add("refresh_token", refreshToken) + } + strBody := body.Encode() + + res := client.Request.Request( + "POST", + client.Endpoints.TokenEndpoint, + io.NopCloser(strings.NewReader(strBody)), + request.WithHeader(http.Header{ + "Content-Type": {"application/x-www-form-urlencoded"}}, + ), + request.WithContentLength(int64(len(strBody))), + ) + if res.Err != nil { + return nil, res.Err + } + + respBody, err := res.GetResponse() + if err != nil { + return nil, err + } + + var ( + errResp OAuthError + credential Credential + decodeErr error + ) + + if res.Response.StatusCode != 200 { + decodeErr = json.Unmarshal([]byte(respBody), &errResp) + } else { + decodeErr = json.Unmarshal([]byte(respBody), &credential) + } + if decodeErr != nil { + return nil, decodeErr + } + + if errResp.ErrorType != "" { + return nil, errResp + } + + return &credential, nil +} + +// UpdateCredential 更新凭证,并检查有效期 +func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error { + if isSlave { + return client.fetchCredentialFromMaster(ctx) + } + + oauth.GlobalMutex.Lock(client.Policy.ID) + defer oauth.GlobalMutex.Unlock(client.Policy.ID) + + // 如果已存在凭证 + if client.Credential != nil && client.Credential.AccessToken != "" { + // 检查已有凭证是否过期 + if client.Credential.ExpiresIn > time.Now().Unix() { + // 未过期,不要更新 + return nil + } + } + + // 尝试从缓存中获取凭证 + if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok { + credential := cacheCredential.(Credential) + if credential.ExpiresIn > time.Now().Unix() { + client.Credential = &credential + return nil + } + } + + // 获取新的凭证 + if client.Credential == nil || client.Credential.RefreshToken == "" { + // 无有效的RefreshToken + util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name) + return ErrInvalidRefreshToken + } + + credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken) + if err != nil { + return err + } + + // 更新有效期为绝对时间戳 + expires := credential.ExpiresIn - 60 + credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix() + // refresh token for Google Drive does not expire in production + credential.RefreshToken = client.Credential.RefreshToken + client.Credential = credential + + // 更新缓存 + cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires)) + + return nil +} + +func (client *Client) AccessToken() string { + return client.Credential.AccessToken +} + +// UpdateCredential 更新凭证,并检查有效期 +func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { + res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) + if err != nil { + return err + } + + client.Credential = &Credential{AccessToken: res} + return nil +} diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index 39beff3..957af8e 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -58,7 +58,7 @@ func NewClient(policy *model.Policy) (*Client, error) { Policy: policy, ClientID: policy.BucketName, ClientSecret: policy.SecretKey, - Redirect: policy.OptionsSerialized.OdRedirect, + Redirect: policy.OptionsSerialized.OauthRedirect, Request: request.NewClient(), ClusterController: cluster.DefaultController, } diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 914a498..bb00005 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -10,6 +10,7 @@ import ( "time" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -128,8 +129,8 @@ func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error return client.fetchCredentialFromMaster(ctx) } - GlobalMutex.Lock(client.Policy.ID) - defer GlobalMutex.Unlock(client.Policy.ID) + oauth.GlobalMutex.Lock(client.Policy.ID) + defer oauth.GlobalMutex.Unlock(client.Policy.ID) // 如果已存在凭证 if client.Credential != nil && client.Credential.AccessToken != "" { @@ -175,9 +176,13 @@ func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error return nil } +func (client *Client) AccessToken() string { + return client.Credential.AccessToken +} + // UpdateCredential 更新凭证,并检查有效期 func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { - res, err := client.ClusterController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) + res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) if err != nil { return err } diff --git a/pkg/filesystem/driver/onedrive/oauth_test.go b/pkg/filesystem/driver/onedrive/oauth_test.go index 61c5e75..b2525b7 100644 --- a/pkg/filesystem/driver/onedrive/oauth_test.go +++ b/pkg/filesystem/driver/onedrive/oauth_test.go @@ -368,7 +368,7 @@ func TestClient_UpdateCredential(t *testing.T) { // slave failed { mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetOneDriveToken", testMock.Anything, testMock.Anything).Return("", errors.New("error")) + mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error")) client.ClusterController = mockController err := client.UpdateCredential(context.Background(), true) asserts.Error(err) @@ -377,7 +377,7 @@ func TestClient_UpdateCredential(t *testing.T) { // slave success { mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetOneDriveToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil) + mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil) client.ClusterController = mockController err := client.UpdateCredential(context.Background(), true) asserts.NoError(err) diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index f1745b6..c7316e6 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -176,6 +176,10 @@ func (fs *FileSystem) DispatchHandler() error { handler, err := s3.NewDriver(currentPolicy) fs.Handler = handler return err + case "googledrive": + handler, err := googledrive.NewDriver(policy) + fs.Handler = handler + return err default: return ErrUnknownPolicyType } diff --git a/pkg/filesystem/driver/onedrive/lock.go b/pkg/filesystem/oauth/mutex.go similarity index 96% rename from pkg/filesystem/driver/onedrive/lock.go rename to pkg/filesystem/oauth/mutex.go index 655936b..41f588d 100644 --- a/pkg/filesystem/driver/onedrive/lock.go +++ b/pkg/filesystem/oauth/mutex.go @@ -1,4 +1,4 @@ -package onedrive +package oauth import "sync" diff --git a/pkg/filesystem/oauth/token.go b/pkg/filesystem/oauth/token.go new file mode 100644 index 0000000..cdc5cf0 --- /dev/null +++ b/pkg/filesystem/oauth/token.go @@ -0,0 +1,8 @@ +package oauth + +import "context" + +type TokenProvider interface { + UpdateCredential(ctx context.Context, isSlave bool) error + AccessToken() string +} diff --git a/pkg/mocks/controllermock/c.go b/pkg/mocks/controllermock/c.go index a2890b2..6a77793 100644 --- a/pkg/mocks/controllermock/c.go +++ b/pkg/mocks/controllermock/c.go @@ -37,7 +37,7 @@ func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, erro return args.Get(0).(*cluster.MasterInfo), args.Error(1) } -func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) { +func (s SlaveControllerMock) GetPolicyOauthToken(s2 string, u uint) (string, error) { args := s.Called(s2, u) return args.String(0), args.Error(1) } diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index 2412777..26d917e 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -192,14 +192,16 @@ func AdminAddSCF(c *gin.Context) { } } -// AdminOneDriveOAuth 获取 OneDrive OAuth URL -func AdminOneDriveOAuth(c *gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindUri(&service); err == nil { - res := service.GetOAuth(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminOAuthURL 获取 OneDrive OAuth URL +func AdminOAuthURL(policyType string) gin.HandlerFunc { + return func(c *gin.Context) { + var service admin.PolicyService + if err := c.ShouldBindUri(&service); err == nil { + res := service.GetOAuth(c, policyType) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } } } diff --git a/routers/controllers/callback.go b/routers/controllers/callback.go index ba3ab46..fec5d07 100644 --- a/routers/controllers/callback.go +++ b/routers/controllers/callback.go @@ -83,9 +83,27 @@ func OneDriveCallback(c *gin.Context) { // OneDriveOAuth OneDrive 授权回调 func OneDriveOAuth(c *gin.Context) { - var callbackBody callback.OneDriveOauthService + var callbackBody callback.OauthService if err := c.ShouldBindQuery(&callbackBody); err == nil { - res := callbackBody.Auth(c) + res := callbackBody.OdAuth(c) + redirect := model.GetSiteURL() + redirect.Path = path.Join(redirect.Path, "/admin/policy") + queries := redirect.Query() + queries.Add("code", strconv.Itoa(res.Code)) + queries.Add("msg", res.Msg) + queries.Add("err", res.Error) + redirect.RawQuery = queries.Encode() + c.Redirect(303, redirect.String()) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// GoogleDriveOAuth Google Drive 授权回调 +func GoogleDriveOAuth(c *gin.Context) { + var callbackBody callback.OauthService + if err := c.ShouldBindQuery(&callbackBody); err == nil { + res := callbackBody.GDriveAuth(c) redirect := model.GetSiteURL() redirect.Path = path.Join(redirect.Path, "/admin/policy") queries := redirect.Query() diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e1e7de2..2df3698 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -223,9 +223,9 @@ func SlaveNotificationPush(c *gin.Context) { } } -// SlaveGetOneDriveCredential 从机获取主机的OneDrive存储策略凭证 -func SlaveGetOneDriveCredential(c *gin.Context) { - var service node.OneDriveCredentialService +// SlaveGetOauthCredential 从机获取主机的OneDrive存储策略凭证 +func SlaveGetOauthCredential(c *gin.Context) { + var service node.OauthCredentialService if err := c.ShouldBindUri(&service); err == nil { res := service.Get(c) c.JSON(200, res) diff --git a/routers/router.go b/routers/router.go index d5f6b43..4e16b77 100644 --- a/routers/router.go +++ b/routers/router.go @@ -260,8 +260,8 @@ func InitMasterRouter() *gin.Engine { // 删除上传会话 upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession) } - // OneDrive 存储策略凭证 - slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential) + // Oauth 存储策略凭证 + slave.GET("credential/:id", controllers.SlaveGetOauthCredential) } // 回调接口 @@ -310,6 +310,15 @@ func InitMasterRouter() *gin.Engine { controllers.OneDriveOAuth, ) } + // Google Drive related + gdrive := callback.Group("googledrive") + { + // OAuth 完成 + gdrive.GET( + "auth", + controllers.GoogleDriveOAuth, + ) + } // 腾讯云COS策略上传回调 callback.GET( "cos/:sessionID", @@ -454,7 +463,14 @@ func InitMasterRouter() *gin.Engine { // 创建COS回调函数 policy.POST("scf", controllers.AdminAddSCF) // 获取 OneDrive OAuth URL - policy.GET(":id/oauth", controllers.AdminOneDriveOAuth) + oauth := policy.Group(":id/oauth") + { + // 获取 OneDrive OAuth URL + oauth.GET("onedrive", controllers.AdminOAuthURL("onedrive")) + // 获取 Google Drive OAuth URL + oauth.GET("googledrive", controllers.AdminOAuthURL("googledrive")) + } + // 获取 存储策略 policy.GET(":id", controllers.AdminGetPolicy) // 删除 存储策略 diff --git a/service/admin/policy.go b/service/admin/policy.go index abfc9da..478203a 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" "net/http" "net/url" "os" @@ -104,27 +105,41 @@ func (service *PolicyService) Get() serializer.Response { } // GetOAuth 获取 OneDrive OAuth 地址 -func (service *PolicyService) GetOAuth(c *gin.Context) serializer.Response { +func (service *PolicyService) GetOAuth(c *gin.Context, policyType string) serializer.Response { policy, err := model.GetPolicyByID(service.ID) - if err != nil || policy.Type != "onedrive" { + if err != nil || policy.Type != policyType { return serializer.Err(serializer.CodePolicyNotExist, "", nil) } - client, err := onedrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) - } - util.SetSession(c, map[string]interface{}{ - "onedrive_oauth_policy": policy.ID, + policyType + "_oauth_policy": policy.ID, }) - cache.Deletes([]string{policy.BucketName}, "onedrive_") + var redirect string + switch policy.Type { + case "onedrive": + client, err := onedrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) + } + + redirect = client.OAuthURL(context.Background(), []string{ + "offline_access", + "files.readwrite.all", + }) + case "googledrive": + client, err := googledrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) + } + + redirect = client.OAuthURL(context.Background(), googledrive.RequiredScope) + } + + // Delete token cache + cache.Deletes([]string{policy.BucketName}, policyType+"_") - return serializer.Response{Data: client.OAuthURL(context.Background(), []string{ - "offline_access", - "files.readwrite.all", - })} + return serializer.Response{Data: redirect} } // AddSCF 创建回调云函数 diff --git a/service/callback/oauth.go b/service/callback/oauth.go index 2494982..f93636b 100644 --- a/service/callback/oauth.go +++ b/service/callback/oauth.go @@ -6,22 +6,70 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" + "github.com/samber/lo" "strings" ) -// OneDriveOauthService OneDrive 授权回调服务 -type OneDriveOauthService struct { +// OauthService OAuth 存储策略授权回调服务 +type OauthService struct { Code string `form:"code"` Error string `form:"error"` ErrorMsg string `form:"error_description"` + Scope string `form:"scope"` } -// Auth 更新认证信息 -func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response { +// GDriveAuth Google Drive 更新认证信息 +func (service *OauthService) GDriveAuth(c *gin.Context) serializer.Response { + if service.Error != "" { + return serializer.ParamErr(service.Error, nil) + } + + // validate required scope + if missing, found := lo.Find[string](googledrive.RequiredScope, func(item string) bool { + return !strings.Contains(service.Scope, item) + }); found { + return serializer.ParamErr(fmt.Sprintf("Missing required scope: %s", missing), nil) + } + + policyID, ok := util.GetSession(c, "googledrive_oauth_policy").(uint) + if !ok { + return serializer.Err(serializer.CodeNotFound, "", nil) + } + + util.DeleteSession(c, "googledrive_oauth_policy") + + policy, err := model.GetPolicyByID(policyID) + if err != nil { + return serializer.Err(serializer.CodePolicyNotExist, "", nil) + } + + client, err := googledrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) + } + + credential, err := client.ObtainToken(c, service.Code, "") + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) + } + + // 更新存储策略的 RefreshToken + client.Policy.AccessKey = credential.RefreshToken + if err := client.Policy.SaveAndClearCache(); err != nil { + return serializer.DBErr("Failed to update RefreshToken", err) + } + + cache.Deletes([]string{client.Policy.AccessKey}, googledrive.TokenCachePrefix) + return serializer.Response{} +} + +// OdAuth OneDrive 更新认证信息 +func (service *OauthService) OdAuth(c *gin.Context) serializer.Response { if service.Error != "" { return serializer.ParamErr(service.ErrorMsg, nil) } diff --git a/service/node/fabric.go b/service/node/fabric.go index a1b6212..deb2184 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -5,7 +5,9 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" @@ -15,7 +17,7 @@ type SlaveNotificationService struct { Subject string `uri:"subject" binding:"required"` } -type OneDriveCredentialService struct { +type OauthCredentialService struct { PolicyID uint `uri:"id" binding:"required"` } @@ -43,21 +45,32 @@ func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) s return serializer.Response{} } -// Get 获取主机OneDrive策略的AccessToken -func (s *OneDriveCredentialService) Get(c *gin.Context) serializer.Response { +// Get 获取主机Oauth策略的AccessToken +func (s *OauthCredentialService) Get(c *gin.Context) serializer.Response { policy, err := model.GetPolicyByID(s.PolicyID) if err != nil { return serializer.Err(serializer.CodePolicyNotExist, "", err) } - client, err := onedrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err) + var client oauth.TokenProvider + switch policy.Type { + case "onedrive": + client, err = onedrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err) + } + case "googledrive": + client, err = googledrive.NewClient(&policy) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize Google Drive client", err) + } + default: + return serializer.Err(serializer.CodePolicyNotExist, "", nil) } if err := client.UpdateCredential(c, conf.SystemConfig.Mode == "slave"); err != nil { return serializer.Err(serializer.CodeInternalSetting, "Cannot refresh OneDrive credential", err) } - return serializer.Response{Data: client.Credential.AccessToken} + return serializer.Response{Data: client.AccessToken()} }