diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go index 231c790..ece638e 100644 --- a/pkg/filesystem/hooks_test.go +++ b/pkg/filesystem/hooks_test.go @@ -9,11 +9,16 @@ import ( "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/local" + "github.com/HFO4/cloudreve/pkg/request" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" + "io" + "io/ioutil" + "net/http" "os" + "strings" "testing" ) @@ -540,3 +545,51 @@ func TestHookSlaveUploadValidate(t *testing.T) { } } + +type ClientMock struct { + testMock.Mock +} + +func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) request.Response { + args := m.Called(method, target, body, opts) + return args.Get(0).(request.Response) +} + +func TestSlaveAfterUpload(t *testing.T) { + asserts := assert.New(t) + conf.SystemConfig.Mode = "slave" + fs, err := NewAnonymousFileSystem() + conf.SystemConfig.Mode = "master" + asserts.NoError(err) + + // 成功 + { + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://test/callbakc", + testMock.Anything, + testMock.Anything, + ).Return(request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), + }, + }) + request.GeneralClient = clientMock + ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{ + Size: 10, + VirtualPath: "/my", + Name: "test.txt", + }) + ctx = context.WithValue(ctx, fsctx.UploadPolicyCtx, serializer.UploadPolicy{ + CallbackURL: "http://test/callbakc", + }) + ctx = context.WithValue(ctx, fsctx.SavePathCtx, "/not_exist") + err := SlaveAfterUpload(ctx, fs) + clientMock.AssertExpectations(t) + asserts.NoError(err) + } +} diff --git a/pkg/request/callback.go b/pkg/request/callback.go index 24ac229..d438832 100644 --- a/pkg/request/callback.go +++ b/pkg/request/callback.go @@ -19,7 +19,7 @@ func RemoteCallback(url string, body serializer.UploadCallback) error { return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) } - resp := generalClient.Request( + resp := GeneralClient.Request( "POST", url, bytes.NewReader(callbackBody), diff --git a/pkg/request/callback_test.go b/pkg/request/callback_test.go index b8e527c..67e4d99 100644 --- a/pkg/request/callback_test.go +++ b/pkg/request/callback_test.go @@ -33,7 +33,7 @@ func TestRemoteCallback(t *testing.T) { Body: ioutil.NopCloser(bytes.NewReader(mockResp)), }, }) - generalClient = clientMock + GeneralClient = clientMock resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ SourceName: "source", }) @@ -58,7 +58,7 @@ func TestRemoteCallback(t *testing.T) { Body: ioutil.NopCloser(bytes.NewReader(mockResp)), }, }) - generalClient = clientMock + GeneralClient = clientMock resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ SourceName: "source", }) @@ -82,7 +82,7 @@ func TestRemoteCallback(t *testing.T) { Body: ioutil.NopCloser(strings.NewReader("mockResp")), }, }) - generalClient = clientMock + GeneralClient = clientMock resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ SourceName: "source", }) @@ -106,7 +106,7 @@ func TestRemoteCallback(t *testing.T) { Body: ioutil.NopCloser(strings.NewReader("mockResp")), }, }) - generalClient = clientMock + GeneralClient = clientMock resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ SourceName: "source", }) @@ -126,7 +126,7 @@ func TestRemoteCallback(t *testing.T) { ).Return(Response{ Err: errors.New("error"), }) - generalClient = clientMock + GeneralClient = clientMock resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ SourceName: "source", }) diff --git a/pkg/request/request.go b/pkg/request/request.go index fc4e9e1..e25585c 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -7,7 +7,7 @@ import ( "time" ) -var generalClient Client = HTTPClient{} +var GeneralClient Client = HTTPClient{} // Response 请求的响应或错误信息 type Response struct {