Feat: init request client with global options

pull/1040/head
HFO4 4 years ago
parent 3b47e314e9
commit 23d1839b29

@ -38,7 +38,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock() node.lock.Lock()
node.Model = nodeModel node.Model = nodeModel
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)} node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}
node.caller.Client = request.HTTPClient{} node.caller.Client = request.NewClient()
node.caller.parent = node node.caller.parent = node
node.Active = true node.Active = true
if node.close != nil { if node.close != nil {

@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) {
ClientID: policy.BucketName, ClientID: policy.BucketName,
ClientSecret: policy.SecretKey, ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OdRedirect, Redirect: policy.OptionsSerialized.OdRedirect,
Request: request.HTTPClient{}, Request: request.NewClient(),
} }
if client.Endpoints.DriverResource == "" { if client.Endpoints.DriverResource == "" {

@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
} }
// 获取公钥 // 获取公钥
client := request.HTTPClient{} client := request.NewClient()
body, err := client.Request("GET", string(pubURL), nil). body, err := client.Request("GET", string(pubURL), nil).
CheckHTTPResponse(200). CheckHTTPResponse(200).
GetResponse() GetResponse()

@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) {
BucketName: "test", BucketName: "test",
Server: "oss-cn-shanghai.aliyuncs.com", Server: "oss-cn-shanghai.aliyuncs.com",
}, },
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
} }
cache.Set("setting_preview_timeout", "3600", 0) cache.Set("setting_preview_timeout", "3600", 0)

@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
} }
// 获取文件数据流 // 获取文件数据流
client := request.HTTPClient{} client := request.NewClient()
resp, err := client.Request( resp, err := client.Request(
"GET", "GET",
downloadURL, downloadURL,

@ -3,5 +3,6 @@ package slave
import "errors" import "errors"
var ( var (
ErrNotImplemented = errors.New("This method of shadowed policy is not implemented") ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
) )

@ -5,7 +5,9 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io" "io"
"net/url" "net/url"
@ -16,6 +18,7 @@ type Driver struct {
node cluster.Node node cluster.Node
handler driver.Handler handler driver.Handler
policy *model.Policy policy *model.Policy
client request.Client
} }
// NewDriver 返回新的从机指派处理器 // NewDriver 返回新的从机指派处理器
@ -24,12 +27,16 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy)
node: node, node: node,
handler: handler, handler: handler,
policy: policy, policy: policy,
client: request.NewClient(request.WithMasterMeta()),
} }
} }
func (d Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { func (d Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
realBase, ok := ctx.Value(fsctx.SlaveSrcPath).(string)
if !ok {
return ErrSlaveSrcPathNotExist
}
panic("implement me")
} }
func (d Driver) Delete(ctx context.Context, files []string) ([]string, error) { func (d Driver) Delete(ctx context.Context, files []string) ([]string, error) {

@ -153,7 +153,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "remote": case "remote":
fs.Handler = remote.Driver{ fs.Handler = remote.Driver{
Policy: currentPolicy, Policy: currentPolicy,
Client: request.HTTPClient{}, Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)}, AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
} }
return nil return nil
@ -165,7 +165,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "oss": case "oss":
fs.Handler = oss.Driver{ fs.Handler = oss.Driver{
Policy: currentPolicy, Policy: currentPolicy,
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
} }
return nil return nil
case "upyun": case "upyun":
@ -178,7 +178,7 @@ func (fs *FileSystem) DispatchHandler() error {
fs.Handler = onedrive.Driver{ fs.Handler = onedrive.Driver{
Policy: currentPolicy, Policy: currentPolicy,
Client: client, Client: client,
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
} }
return err return err
case "cos": case "cos":
@ -192,7 +192,7 @@ func (fs *FileSystem) DispatchHandler() error {
SecretKey: currentPolicy.SecretKey, SecretKey: currentPolicy.SecretKey,
}, },
}), }),
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
} }
return nil return nil
case "s3": case "s3":

@ -0,0 +1,92 @@
package request
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"net/http"
"time"
)
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
timeout time.Duration
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
}
type optionFunc func(*options)
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
header: http.Header{},
timeout: time.Duration(30) * time.Second,
contentLength: -1,
}
}
// WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) {
o.timeout = t
})
}
// WithContext 设置请求上下文
func WithContext(c context.Context) Option {
return optionFunc(func(o *options) {
o.ctx = c
})
}
// WithCredential 对请求进行签名
func WithCredential(instance auth.Auth, ttl int64) Option {
return optionFunc(func(o *options) {
o.sign = instance
o.signTTL = ttl
})
}
// WithHeader 设置请求Header
func WithHeader(header http.Header) Option {
return optionFunc(func(o *options) {
for k, v := range header {
o.header[k] = v
}
})
}
// WithoutHeader 设置清除请求Header
func WithoutHeader(header []string) Option {
return optionFunc(func(o *options) {
for _, v := range header {
delete(o.header, v)
}
})
}
// WithContentLength 设置请求大小
func WithContentLength(s int64) Option {
return optionFunc(func(o *options) {
o.contentLength = s
})
}
// WithMasterMeta 请求时携带主机信息
func WithMasterMeta() Option {
return optionFunc(func(o *options) {
o.masterMeta = true
})
}

@ -1,14 +1,12 @@
package request package request
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
@ -33,105 +31,33 @@ type Client interface {
// HTTPClient 实现 Client 接口 // HTTPClient 实现 Client 接口
type HTTPClient struct { type HTTPClient struct {
options *options
} }
// Option 发送请求的额外设置 func NewClient(opts ...Option) Client {
type Option interface { client := &HTTPClient{
apply(*options) options: newDefaultOption(),
} }
type options struct { for _, o := range opts {
timeout time.Duration o.apply(client.options)
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
}
type optionFunc func(*options)
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
header: http.Header{},
timeout: time.Duration(30) * time.Second,
contentLength: -1,
}
}
// WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) {
o.timeout = t
})
}
// WithContext 设置请求上下文
func WithContext(c context.Context) Option {
return optionFunc(func(o *options) {
o.ctx = c
})
}
// WithCredential 对请求进行签名
func WithCredential(instance auth.Auth, ttl int64) Option {
return optionFunc(func(o *options) {
o.sign = instance
o.signTTL = ttl
})
}
// WithHeader 设置请求Header
func WithHeader(header http.Header) Option {
return optionFunc(func(o *options) {
for k, v := range header {
o.header[k] = v
}
})
}
// WithoutHeader 设置清除请求Header
func WithoutHeader(header []string) Option {
return optionFunc(func(o *options) {
for _, v := range header {
delete(o.header, v)
}
})
}
// WithContentLength 设置请求大小
func WithContentLength(s int64) Option {
return optionFunc(func(o *options) {
o.contentLength = s
})
} }
// WithMasterMeta 请求时携带主机信息 return client
func WithMasterMeta() Option {
return optionFunc(func(o *options) {
o.masterMeta = true
})
} }
// Request 发送HTTP请求 // Request 发送HTTP请求
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置 // 应用额外设置
options := newDefaultOption()
for _, o := range opts { for _, o := range opts {
o.apply(options) o.apply(c.options)
} }
// 创建请求客户端 // 创建请求客户端
client := &http.Client{Timeout: options.timeout} client := &http.Client{Timeout: c.options.timeout}
// size为0时将body设为nil // size为0时将body设为nil
if options.contentLength == 0 { if c.options.contentLength == 0 {
body = nil body = nil
} }
@ -140,8 +66,8 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
req *http.Request req *http.Request
err error err error
) )
if options.ctx != nil { if c.options.ctx != nil {
req, err = http.NewRequestWithContext(options.ctx, method, target, body) req, err = http.NewRequestWithContext(c.options.ctx, method, target, body)
} else { } else {
req, err = http.NewRequest(method, target, body) req, err = http.NewRequest(method, target, body)
} }
@ -150,21 +76,21 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
} }
// 添加请求相关设置 // 添加请求相关设置
req.Header = options.header req.Header = c.options.header
if options.masterMeta { if c.options.masterMeta {
req.Header.Add("X-Site-Url", model.GetSiteURL().String()) req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID")) req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
} }
if options.contentLength != -1 { if c.options.contentLength != -1 {
req.ContentLength = options.contentLength req.ContentLength = c.options.contentLength
} }
// 签名请求 // 签名请求
if options.sign != nil { if c.options.sign != nil {
auth.SignRequest(options.sign, req, options.signTTL) auth.SignRequest(c.options.sign, req, c.options.signTTL)
} }
// 发送请求 // 发送请求

@ -47,7 +47,7 @@ type masterInfo struct {
func Init() { func Init() {
DefaultController = &slaveController{ DefaultController = &slaveController{
masters: make(map[string]masterInfo), masters: make(map[string]masterInfo),
client: request.HTTPClient{}, client: request.NewClient(),
} }
gob.Register(rpc.StatusInfo{}) gob.Register(rpc.StatusInfo{})
} }

@ -24,7 +24,7 @@ func AdminSummary(c *gin.Context) {
// AdminNews 获取社区新闻 // AdminNews 获取社区新闻
func AdminNews(c *gin.Context) { func AdminNews(c *gin.Context) {
r := request.HTTPClient{} r := request.NewClient()
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil) res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
if res.Err == nil { if res.Err == nil {
io.Copy(c.Writer, res.Response.Body) io.Copy(c.Writer, res.Response.Body)

@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
case "oss": case "oss":
handler := oss.Driver{ handler := oss.Driver{
Policy: &policy, Policy: &policy,
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
} }
if err := handler.CORS(); err != nil { if err := handler.CORS(); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err)
@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
b := &cossdk.BaseURL{BucketURL: u} b := &cossdk.BaseURL{BucketURL: u}
handler := cos.Driver{ handler := cos.Driver{
Policy: &policy, Policy: &policy,
HTTPClient: request.HTTPClient{}, HTTPClient: request.NewClient(),
Client: cossdk.NewClient(b, &http.Client{ Client: cossdk.NewClient(b, &http.Client{
Transport: &cossdk.AuthorizationTransport{ Transport: &cossdk.AuthorizationTransport{
SecretID: policy.AccessKey, SecretID: policy.AccessKey,
@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response {
controller, _ := url.Parse("/api/v3/site/ping") controller, _ := url.Parse("/api/v3/site/ping")
r := request.HTTPClient{} r := request.NewClient()
res, err := r.Request( res, err := r.Request(
"GET", "GET",
master.ResolveReference(controller).String(), master.ResolveReference(controller).String(),
@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response {
} }
bodyByte, _ := json.Marshal(body) bodyByte, _ := json.Marshal(body)
r := request.HTTPClient{} r := request.NewClient()
res, err := r.Request( res, err := r.Request(
"POST", "POST",
slave.ResolveReference(controller).String(), slave.ResolveReference(controller).String(),

Loading…
Cancel
Save