parent
5b9de0e097
commit
64342fa88d
@ -0,0 +1,59 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RemoteCallback 发送远程存储策略上传回调请求
|
||||
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
||||
callbackBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err)
|
||||
}
|
||||
|
||||
resp := generalClient.Request(
|
||||
"POST",
|
||||
url,
|
||||
bytes.NewReader(callbackBody),
|
||||
WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
|
||||
WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
|
||||
)
|
||||
|
||||
if resp.Err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err)
|
||||
}
|
||||
|
||||
// 检查返回HTTP状态码
|
||||
if resp.Response.StatusCode != 200 {
|
||||
util.Log().Debug("服务端返回非正常状态码:%d", resp.Response.StatusCode)
|
||||
return serializer.NewError(serializer.CodeCallbackError, "服务端返回非正常状态码", nil)
|
||||
}
|
||||
|
||||
// 检查返回API状态码
|
||||
var response serializer.Response
|
||||
rawResp, err := ioutil.ReadAll(resp.Response.Body)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "无法读取响应正文", err)
|
||||
}
|
||||
|
||||
// 解析回调服务端响应
|
||||
err = json.Unmarshal(rawResp, &response)
|
||||
if err != nil {
|
||||
util.Log().Debug("无法解析回调服务端响应:%s", string(rawResp))
|
||||
return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,136 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRemoteCallback(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
// 回调成功
|
||||
{
|
||||
clientMock := ClientMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
generalClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
|
||||
SourceName: "source",
|
||||
})
|
||||
asserts.NoError(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 服务端返回业务错误
|
||||
{
|
||||
clientMock := ClientMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
generalClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
|
||||
SourceName: "source",
|
||||
})
|
||||
asserts.EqualValues(401, resp.(serializer.AppError).Code)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法解析回调响应
|
||||
{
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
generalClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
|
||||
SourceName: "source",
|
||||
})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// HTTP状态码非200
|
||||
{
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
generalClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
|
||||
SourceName: "source",
|
||||
})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法发起回调
|
||||
{
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(Response{
|
||||
Err: errors.New("error"),
|
||||
})
|
||||
generalClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{
|
||||
SourceName: "source",
|
||||
})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var generalClient Client = HTTPClient{}
|
||||
|
||||
// Response 请求的响应或错误信息
|
||||
type Response struct {
|
||||
Err error
|
||||
Response *http.Response
|
||||
}
|
||||
|
||||
// Client 请求客户端
|
||||
type Client interface {
|
||||
Request(method, target string, body io.Reader, opts ...Option) Response
|
||||
}
|
||||
|
||||
// HTTPClient 实现 Client 接口
|
||||
type HTTPClient struct {
|
||||
}
|
||||
|
||||
// Option 发送请求的额外设置
|
||||
type Option interface {
|
||||
apply(*options)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
timeout time.Duration
|
||||
header http.Header
|
||||
sign auth.Auth
|
||||
signTTL int64
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
header: http.Header{},
|
||||
timeout: time.Duration(30) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout 设置请求超时
|
||||
func WithTimeout(t time.Duration) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.timeout = t
|
||||
})
|
||||
}
|
||||
|
||||
// WithCredential 对请求进行签名
|
||||
func WithCredential(instance auth.Auth, ttl int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.sign = instance
|
||||
o.signTTL = ttl
|
||||
})
|
||||
}
|
||||
|
||||
// WithHeader 设置请求Header
|
||||
func WithHeader(header http.Header) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.header = header
|
||||
})
|
||||
}
|
||||
|
||||
// Request 发送HTTP请求
|
||||
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) Response {
|
||||
// 应用额外设置
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
// 创建请求客户端
|
||||
client := &http.Client{Timeout: options.timeout}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequest(method, target, body)
|
||||
if err != nil {
|
||||
return Response{Err: err}
|
||||
}
|
||||
|
||||
// 添加请求header
|
||||
req.Header = options.header
|
||||
|
||||
// 签名请求
|
||||
if options.sign != nil {
|
||||
auth.SignRequest(options.sign, req, time.Now().Unix()+options.signTTL)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return Response{Err: err}
|
||||
}
|
||||
|
||||
return Response{Err: nil, Response: resp}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ClientMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) Response {
|
||||
args := m.Called(method, target, body, opts)
|
||||
return args.Get(0).(Response)
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
options := newDefaultOption()
|
||||
WithTimeout(time.Duration(5) * time.Second).apply(options)
|
||||
asserts.Equal(time.Duration(5)*time.Second, options.timeout)
|
||||
}
|
||||
|
||||
func TestWithHeader(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
options := newDefaultOption()
|
||||
WithHeader(map[string][]string{"Origin": []string{"123"}}).apply(options)
|
||||
asserts.Equal(http.Header{"Origin": []string{"123"}}, options.header)
|
||||
}
|
||||
|
||||
func TestWithCredential(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
options := newDefaultOption()
|
||||
WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10).apply(options)
|
||||
asserts.Equal(auth.HMACAuth{SecretKey: []byte("123")}, options.sign)
|
||||
asserts.EqualValues(10, options.signTTL)
|
||||
}
|
||||
|
||||
func TestHTTPClient_Request(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
client := HTTPClient{}
|
||||
|
||||
// 正常
|
||||
{
|
||||
resp := client.Request(
|
||||
"GET",
|
||||
"http://cloudreveisnotexist.com",
|
||||
strings.NewReader(""),
|
||||
WithTimeout(time.Duration(1)*time.Microsecond),
|
||||
WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10),
|
||||
)
|
||||
asserts.Error(resp.Err)
|
||||
asserts.Nil(resp.Response)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in new issue