Feat: support using SharePoint site to store files

pull/796/head
HFO4 4 years ago
parent a54acd71c2
commit 4e2f243436

@ -51,6 +51,8 @@ type PolicyOption struct {
OdRedirect string `json:"od_redirect,omitempty"` OdRedirect string `json:"od_redirect,omitempty"`
// OdProxy Onedrive 反代地址 // OdProxy Onedrive 反代地址
OdProxy string `json:"od_proxy,omitempty"` OdProxy string `json:"od_proxy,omitempty"`
// OdDriver OneDrive 驱动器定位符
OdDriver string `json:"od_driver,omitempty"`
// Region 区域代码 // Region 区域代码
Region string `json:"region,omitempty"` Region string `json:"region,omitempty"`
// ServerSideEndpoint 服务端请求使用的 Endpoint为空时使用 Policy.Server 字段 // ServerSideEndpoint 服务端请求使用的 Endpoint为空时使用 Policy.Server 字段
@ -268,9 +270,8 @@ func (policy *Policy) GetUploadURL() string {
return server.ResolveReference(controller).String() return server.ResolveReference(controller).String()
} }
// UpdateAccessKey 更新 AccessKey // SaveAndClearCache 更新并清理缓存
func (policy *Policy) UpdateAccessKey(key string) error { func (policy *Policy) SaveAndClearCache() error {
policy.AccessKey = key
err := DB.Save(policy).Error err := DB.Save(policy).Error
policy.ClearCache() policy.ClearCache()
return err return err

@ -257,7 +257,8 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() mock.ExpectCommit()
err := policy.UpdateAccessKey("123") policy.AccessKey = "123"
err := policy.SaveAndClearCache()
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err) asserts.NoError(err)
} }

@ -53,12 +53,23 @@ func (err RespError) Error() string {
return err.APIError.Message return err.APIError.Message
} }
func (client *Client) getRequestURL(api string) string { func (client *Client) getRequestURL(api string, opts ...Option) string {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
base, _ := url.Parse(client.Endpoints.EndpointURL) base, _ := url.Parse(client.Endpoints.EndpointURL)
if base == nil { if base == nil {
return "" return ""
} }
if options.useDriverResource {
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
} else {
base.Path = path.Join(base.Path, api) base.Path = path.Join(base.Path, api)
}
return base.String() return base.String()
} }
@ -67,9 +78,9 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
var requestURL string var requestURL string
dst := strings.TrimPrefix(path, "/") dst := strings.TrimPrefix(path, "/")
if dst == "" { if dst == "" {
requestURL = client.getRequestURL("me/drive/root/children") requestURL = client.getRequestURL("root/children")
} else { } else {
requestURL = client.getRequestURL("me/drive/root:/" + dst + ":/children") requestURL = client.getRequestURL("root:/" + dst + ":/children")
} }
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200) res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
@ -103,10 +114,10 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) { func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
var requestURL string var requestURL string
if id != "" { if id != "" {
requestURL = client.getRequestURL("/me/drive/items/" + id) requestURL = client.getRequestURL("items/" + id)
} else { } else {
dst := strings.TrimPrefix(path, "/") dst := strings.TrimPrefix(path, "/")
requestURL = client.getRequestURL("me/drive/root:/" + dst) requestURL = client.getRequestURL("root:/" + dst)
} }
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200) res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
@ -129,14 +140,13 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn
// CreateUploadSession 创建分片上传会话 // CreateUploadSession 创建分片上传会话
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) { func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
options := newDefaultOption() options := newDefaultOption()
for _, o := range opts { for _, o := range opts {
o.apply(options) o.apply(options)
} }
dst = strings.TrimPrefix(dst, "/") dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/createUploadSession") requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
body := map[string]map[string]interface{}{ body := map[string]map[string]interface{}{
"item": { "item": {
"@microsoft.graph.conflictBehavior": options.conflictBehavior, "@microsoft.graph.conflictBehavior": options.conflictBehavior,
@ -161,6 +171,33 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts
return uploadSession.UploadURL, nil return uploadSession.UploadURL, nil
} }
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
siteUrlParsed, err := url.Parse(siteUrl)
if err != nil {
return "", err
}
hostName := siteUrlParsed.Hostname()
relativePath := strings.Trim(siteUrlParsed.Path, "/")
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if reqErr != nil {
return "", reqErr
}
var (
decodeErr error
siteInfo Site
)
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
if decodeErr != nil {
return "", decodeErr
}
return siteInfo.ID, nil
}
// GetUploadSessionStatus 查询上传会话状态 // GetUploadSessionStatus 查询上传会话状态
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) { func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200) res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
@ -300,7 +337,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read
} }
dst = strings.TrimPrefix(dst, "/") dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/content") requestURL := client.getRequestURL("root:/" + dst + ":/content")
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior) requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)), res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
@ -357,7 +394,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
// 由于API限制最多删除20个 // 由于API限制最多删除20个
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) { func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
body := client.makeBatchDeleteRequestsBody(dst) body := client.makeBatchDeleteRequestsBody(dst)
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch"), body, 200) res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
WithDriverResource(false)), body, 200)
if err != nil { if err != nil {
return dst, err return dst, err
} }
@ -396,7 +434,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
} }
for i, v := range files { for i, v := range files {
v = strings.TrimPrefix(v, "/") v = strings.TrimPrefix(v, "/")
filePath, _ := url.Parse("/me/drive/root:/") filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
filePath.Path = path.Join(filePath.Path, v) filePath.Path = path.Join(filePath.Path, v)
req.Requests[i] = BatchRequest{ req.Requests[i] = BatchRequest{
ID: v, ID: v,
@ -418,10 +456,10 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s
) )
if client.Endpoints.isInChina { if client.Endpoints.isInChina {
cropOption = "large" cropOption = "large"
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails/0") + "/" + cropOption requestURL = client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/" + cropOption
} else { } else {
cropOption = fmt.Sprintf("c%dx%d_Crop", w, h) cropOption = fmt.Sprintf("c%dx%d_Crop", w, h)
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails") + "?select=" + cropOption requestURL = client.getRequestURL("root:/"+dst+":/thumbnails") + "?select=" + cropOption
} }
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200) res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)

@ -167,6 +167,82 @@ func TestClient_GetRequestURL(t *testing.T) {
client.Endpoints.EndpointURL = string([]byte{0x7f}) client.Endpoints.EndpointURL = string([]byte{0x7f})
asserts.Equal("", client.getRequestURL("123")) asserts.Equal("", client.getRequestURL("123"))
} }
// 使用DriverResource
{
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123"))
}
// 不使用DriverResource
{
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false)))
}
}
func TestClient_GetSiteIDByURL(t *testing.T) {
asserts := assert.New(t)
client, _ := NewClient(&model.Policy{})
client.Credential.AccessToken = "AccessToken"
// 请求失败
{
client.Credential.ExpiresIn = 0
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
asserts.Error(err)
asserts.Empty(res)
}
// 返回未知响应
{
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`???`)),
},
})
client.Request = clientMock
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Empty(res)
}
// 返回正常
{
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)),
},
})
client.Request = clientMock
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
clientMock.AssertExpectations(t)
asserts.NoError(err)
asserts.NotEmpty(res)
asserts.Equal("123321", res)
}
} }
func TestClient_Meta(t *testing.T) { func TestClient_Meta(t *testing.T) {

@ -37,6 +37,7 @@ type Endpoints struct {
OAuthEndpoints *oauthEndpoint OAuthEndpoints *oauthEndpoint
EndpointURL string // 接口请求的基URL EndpointURL string // 接口请求的基URL
isInChina bool // 是否为世纪互联 isInChina bool // 是否为世纪互联
DriverResource string // 要使用的驱动器
} }
// NewClient 根据存储策略获取新的client // NewClient 根据存储策略获取新的client
@ -45,6 +46,7 @@ func NewClient(policy *model.Policy) (*Client, error) {
Endpoints: &Endpoints{ Endpoints: &Endpoints{
OAuthURL: policy.BaseURL, OAuthURL: policy.BaseURL,
EndpointURL: policy.Server, EndpointURL: policy.Server,
DriverResource: policy.OptionsSerialized.OdDriver,
}, },
Credential: &Credential{ Credential: &Credential{
RefreshToken: policy.AccessKey, RefreshToken: policy.AccessKey,
@ -56,6 +58,10 @@ func NewClient(policy *model.Policy) (*Client, error) {
Request: request.HTTPClient{}, Request: request.HTTPClient{},
} }
if client.Endpoints.DriverResource == "" {
client.Endpoints.DriverResource = "me/drive"
}
oauthBase := client.getOAuthEndpoint() oauthBase := client.getOAuthEndpoint()
if oauthBase == nil { if oauthBase == nil {
return nil, ErrAuthEndpoint return nil, ErrAuthEndpoint

@ -160,7 +160,8 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
client.Credential = credential client.Credential = credential
// 更新存储策略的 RefreshToken // 更新存储策略的 RefreshToken
client.Policy.UpdateAccessKey(credential.RefreshToken) client.Policy.AccessKey = credential.RefreshToken
client.Policy.SaveAndClearCache()
// 更新缓存 // 更新缓存
cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) cache.Set("onedrive_"+client.ClientID, *credential, int(expires))

@ -13,6 +13,7 @@ type options struct {
refreshToken string refreshToken string
conflictBehavior string conflictBehavior string
expires time.Time expires time.Time
useDriverResource bool
} }
type optionFunc func(*options) type optionFunc func(*options)
@ -38,6 +39,13 @@ func WithConflictBehavior(t string) Option {
}) })
} }
// WithConflictBehavior 设置文件重名后的处理方式
func WithDriverResource(t bool) Option {
return optionFunc(func(o *options) {
o.useDriverResource = t
})
}
func (f optionFunc) apply(o *options) { func (f optionFunc) apply(o *options) {
f(o) f(o)
} }
@ -45,6 +53,7 @@ func (f optionFunc) apply(o *options) {
func newDefaultOption() *options { func newDefaultOption() *options {
return &options{ return &options{
conflictBehavior: "fail", conflictBehavior: "fail",
useDriverResource: true,
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour), expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
} }
} }

@ -131,6 +131,15 @@ type OAuthError struct {
CorrelationID string `json:"correlation_id"` CorrelationID string `json:"correlation_id"`
} }
// Site SharePoint 站点信息
type Site struct {
Description string `json:"description"`
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
WebUrl string `json:"webUrl"`
}
func init() { func init() {
gob.Register(Credential{}) gob.Register(Credential{})
} }

@ -2,6 +2,7 @@ package callback
import ( import (
"context" "context"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
@ -41,17 +42,42 @@ func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response {
return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err) return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err)
} }
credential, err := client.ObtainToken(context.Background(), onedrive.WithCode(service.Code)) credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code))
if err != nil { if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err) return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err)
} }
// 更新存储策略的 RefreshToken // 更新存储策略的 RefreshToken
if err := client.Policy.UpdateAccessKey(credential.RefreshToken); err != nil { client.Policy.AccessKey = credential.RefreshToken
if err := client.Policy.SaveAndClearCache(); err != nil {
return serializer.DBErr("无法更新 RefreshToken", err) return serializer.DBErr("无法更新 RefreshToken", err)
} }
cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_") cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_")
if client.Policy.OptionsSerialized.OdDriver != "" {
if err := querySharePointSiteID(c, client.Policy); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "无法查询 SharePoint 站点 ID", err)
}
}
return serializer.Response{} return serializer.Response{}
} }
func querySharePointSiteID(ctx context.Context, policy *model.Policy) error {
client, err := onedrive.NewClient(policy)
if err != nil {
return err
}
id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver)
if err != nil {
return err
}
client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id)
if err := client.Policy.SaveAndClearCache(); err != nil {
return err
}
return nil
}

Loading…
Cancel
Save