diff --git a/models/policy.go b/models/policy.go index 52f9668..dfd068d 100644 --- a/models/policy.go +++ b/models/policy.go @@ -51,6 +51,8 @@ type PolicyOption struct { OdRedirect string `json:"od_redirect,omitempty"` // OdProxy Onedrive 反代地址 OdProxy string `json:"od_proxy,omitempty"` + // OdDriver OneDrive 驱动器定位符 + OdDriver string `json:"od_driver,omitempty"` // Region 区域代码 Region string `json:"region,omitempty"` // ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段 @@ -268,9 +270,8 @@ func (policy *Policy) GetUploadURL() string { return server.ResolveReference(controller).String() } -// UpdateAccessKey 更新 AccessKey -func (policy *Policy) UpdateAccessKey(key string) error { - policy.AccessKey = key +// SaveAndClearCache 更新并清理缓存 +func (policy *Policy) SaveAndClearCache() error { err := DB.Save(policy).Error policy.ClearCache() return err diff --git a/models/policy_test.go b/models/policy_test.go index 6168148..91de270 100644 --- a/models/policy_test.go +++ b/models/policy_test.go @@ -257,7 +257,8 @@ func TestPolicy_UpdateAccessKey(t *testing.T) { mock.ExpectBegin() mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() - err := policy.UpdateAccessKey("123") + policy.AccessKey = "123" + err := policy.SaveAndClearCache() asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) } diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 0a469ea..6dfb6ca 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -53,12 +53,23 @@ func (err RespError) Error() string { return err.APIError.Message } -func (client *Client) getRequestURL(api string) string { +func (client *Client) getRequestURL(api string, opts ...Option) string { + options := newDefaultOption() + for _, o := range opts { + o.apply(options) + } + base, _ := url.Parse(client.Endpoints.EndpointURL) if base == nil { return "" } - base.Path = path.Join(base.Path, api) + + if options.useDriverResource { + base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api) + } else { + base.Path = path.Join(base.Path, api) + } + return base.String() } @@ -67,9 +78,9 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo var requestURL string dst := strings.TrimPrefix(path, "/") if dst == "" { - requestURL = client.getRequestURL("me/drive/root/children") + requestURL = client.getRequestURL("root/children") } else { - requestURL = client.getRequestURL("me/drive/root:/" + dst + ":/children") + requestURL = client.getRequestURL("root:/" + dst + ":/children") } res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200) @@ -103,10 +114,10 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) { var requestURL string if id != "" { - requestURL = client.getRequestURL("/me/drive/items/" + id) + requestURL = client.getRequestURL("items/" + id) } else { dst := strings.TrimPrefix(path, "/") - requestURL = client.getRequestURL("me/drive/root:/" + dst) + requestURL = client.getRequestURL("root:/" + dst) } res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200) @@ -129,14 +140,13 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn // CreateUploadSession 创建分片上传会话 func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) { - options := newDefaultOption() for _, o := range opts { o.apply(options) } dst = strings.TrimPrefix(dst, "/") - requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/createUploadSession") + requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession") body := map[string]map[string]interface{}{ "item": { "@microsoft.graph.conflictBehavior": options.conflictBehavior, @@ -161,6 +171,33 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts return uploadSession.UploadURL, nil } +// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID +func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) { + siteUrlParsed, err := url.Parse(siteUrl) + if err != nil { + return "", err + } + + hostName := siteUrlParsed.Hostname() + relativePath := strings.Trim(siteUrlParsed.Path, "/") + requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false)) + res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200) + if reqErr != nil { + return "", reqErr + } + + var ( + decodeErr error + siteInfo Site + ) + decodeErr = json.Unmarshal([]byte(res), &siteInfo) + if decodeErr != nil { + return "", decodeErr + } + + return siteInfo.ID, nil +} + // GetUploadSessionStatus 查询上传会话状态 func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) { res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200) @@ -300,7 +337,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read } dst = strings.TrimPrefix(dst, "/") - requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/content") + requestURL := client.getRequestURL("root:/" + dst + ":/content") requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior) res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)), @@ -357,7 +394,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, // 由于API限制,最多删除20个 func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) { body := client.makeBatchDeleteRequestsBody(dst) - res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch"), body, 200) + res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch", + WithDriverResource(false)), body, 200) if err != nil { return dst, err } @@ -396,7 +434,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string { } for i, v := range files { v = strings.TrimPrefix(v, "/") - filePath, _ := url.Parse("/me/drive/root:/") + filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/") filePath.Path = path.Join(filePath.Path, v) req.Requests[i] = BatchRequest{ ID: v, @@ -418,10 +456,10 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s ) if client.Endpoints.isInChina { cropOption = "large" - requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails/0") + "/" + cropOption + requestURL = client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/" + cropOption } else { cropOption = fmt.Sprintf("c%dx%d_Crop", w, h) - requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails") + "?select=" + cropOption + requestURL = client.getRequestURL("root:/"+dst+":/thumbnails") + "?select=" + cropOption } res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200) diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go index 1f89576..b0324ea 100644 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -167,6 +167,82 @@ func TestClient_GetRequestURL(t *testing.T) { client.Endpoints.EndpointURL = string([]byte{0x7f}) asserts.Equal("", client.getRequestURL("123")) } + + // 使用DriverResource + { + client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" + asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123")) + } + + // 不使用DriverResource + { + client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" + asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false))) + } +} + +func TestClient_GetSiteIDByURL(t *testing.T) { + asserts := assert.New(t) + client, _ := NewClient(&model.Policy{}) + client.Credential.AccessToken = "AccessToken" + + // 请求失败 + { + client.Credential.ExpiresIn = 0 + res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") + asserts.Error(err) + asserts.Empty(res) + + } + + // 返回未知响应 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`???`)), + }, + }) + client.Request = clientMock + res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + } + + // 返回正常 + { + client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() + clientMock := ClientMock{} + clientMock.On( + "Request", + "GET", + testMock.Anything, + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)), + }, + }) + client.Request = clientMock + res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") + clientMock.AssertExpectations(t) + asserts.NoError(err) + asserts.NotEmpty(res) + asserts.Equal("123321", res) + } } func TestClient_Meta(t *testing.T) { diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index 2767fe2..101f9c3 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -37,14 +37,16 @@ type Endpoints struct { OAuthEndpoints *oauthEndpoint EndpointURL string // 接口请求的基URL isInChina bool // 是否为世纪互联 + DriverResource string // 要使用的驱动器 } // NewClient 根据存储策略获取新的client func NewClient(policy *model.Policy) (*Client, error) { client := &Client{ Endpoints: &Endpoints{ - OAuthURL: policy.BaseURL, - EndpointURL: policy.Server, + OAuthURL: policy.BaseURL, + EndpointURL: policy.Server, + DriverResource: policy.OptionsSerialized.OdDriver, }, Credential: &Credential{ RefreshToken: policy.AccessKey, @@ -56,6 +58,10 @@ func NewClient(policy *model.Policy) (*Client, error) { Request: request.HTTPClient{}, } + if client.Endpoints.DriverResource == "" { + client.Endpoints.DriverResource = "me/drive" + } + oauthBase := client.getOAuthEndpoint() if oauthBase == nil { return nil, ErrAuthEndpoint diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 607accd..9b33d7a 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -160,7 +160,8 @@ func (client *Client) UpdateCredential(ctx context.Context) error { client.Credential = credential // 更新存储策略的 RefreshToken - client.Policy.UpdateAccessKey(credential.RefreshToken) + client.Policy.AccessKey = credential.RefreshToken + client.Policy.SaveAndClearCache() // 更新缓存 cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) diff --git a/pkg/filesystem/driver/onedrive/options.go b/pkg/filesystem/driver/onedrive/options.go index 5d09ef7..0c8c107 100644 --- a/pkg/filesystem/driver/onedrive/options.go +++ b/pkg/filesystem/driver/onedrive/options.go @@ -8,11 +8,12 @@ type Option interface { } type options struct { - redirect string - code string - refreshToken string - conflictBehavior string - expires time.Time + redirect string + code string + refreshToken string + conflictBehavior string + expires time.Time + useDriverResource bool } type optionFunc func(*options) @@ -38,13 +39,21 @@ func WithConflictBehavior(t string) Option { }) } +// WithConflictBehavior 设置文件重名后的处理方式 +func WithDriverResource(t bool) Option { + return optionFunc(func(o *options) { + o.useDriverResource = t + }) +} + func (f optionFunc) apply(o *options) { f(o) } func newDefaultOption() *options { return &options{ - conflictBehavior: "fail", - expires: time.Now().UTC().Add(time.Duration(1) * time.Hour), + conflictBehavior: "fail", + useDriverResource: true, + expires: time.Now().UTC().Add(time.Duration(1) * time.Hour), } } diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filesystem/driver/onedrive/types.go index fdf64d5..9dc14fa 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filesystem/driver/onedrive/types.go @@ -131,6 +131,15 @@ type OAuthError struct { CorrelationID string `json:"correlation_id"` } +// Site SharePoint 站点信息 +type Site struct { + Description string `json:"description"` + ID string `json:"id"` + Name string `json:"name"` + DisplayName string `json:"displayName"` + WebUrl string `json:"webUrl"` +} + func init() { gob.Register(Credential{}) } diff --git a/service/callback/oauth.go b/service/callback/oauth.go index 1515f99..cbbe6d2 100644 --- a/service/callback/oauth.go +++ b/service/callback/oauth.go @@ -2,6 +2,7 @@ package callback import ( "context" + "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" @@ -41,17 +42,42 @@ func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err) } - credential, err := client.ObtainToken(context.Background(), onedrive.WithCode(service.Code)) + credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code)) if err != nil { return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err) } // 更新存储策略的 RefreshToken - if err := client.Policy.UpdateAccessKey(credential.RefreshToken); err != nil { + client.Policy.AccessKey = credential.RefreshToken + if err := client.Policy.SaveAndClearCache(); err != nil { return serializer.DBErr("无法更新 RefreshToken", err) } cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_") + if client.Policy.OptionsSerialized.OdDriver != "" { + if err := querySharePointSiteID(c, client.Policy); err != nil { + return serializer.Err(serializer.CodeInternalSetting, "无法查询 SharePoint 站点 ID", err) + } + } return serializer.Response{} } + +func querySharePointSiteID(ctx context.Context, policy *model.Policy) error { + client, err := onedrive.NewClient(policy) + if err != nil { + return err + } + + id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver) + if err != nil { + return err + } + + client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id) + if err := client.Policy.SaveAndClearCache(); err != nil { + return err + } + + return nil +}