diff --git a/bootstrap/init.go b/bootstrap/init.go index 6463164..11fa45e 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -37,7 +37,7 @@ func Init(path string) { { "both", func() { - cache.Init() + cache.Init(conf.SystemConfig.Mode == "slave") }, }, { diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 8df967d..74c1219 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -10,7 +10,7 @@ import ( var Store Driver = NewMemoStore() // Init 初始化缓存 -func Init() { +func Init(isSlave bool) { if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode { Store = NewRedisStore( 10, @@ -21,7 +21,7 @@ func Init() { ) } - if conf.SystemConfig.Mode == "slave" { + if isSlave { err := Store.Sets(conf.OptionOverwrite, "setting_") if err != nil { util.Log().Warning("无法覆盖数据库设置: %s", err) diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go index d30a67f..a0c5cfc 100644 --- a/pkg/cache/driver_test.go +++ b/pkg/cache/driver_test.go @@ -56,6 +56,10 @@ func TestInit(t *testing.T) { asserts := assert.New(t) asserts.NotPanics(func() { - Init() + Init(false) + }) + + asserts.NotPanics(func() { + Init(true) }) } diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 49e2a48..79118b2 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -413,7 +413,6 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { return strings.NewReader(string(reqBodyEncoded)), nil } -// TODO: move to slave pkg // RemoteCallback 发送远程存储策略上传回调请求 func RemoteCallback(url string, body serializer.UploadCallback) error { callbackBody, err := json.Marshal(struct { diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go index 2580936..47f4bf2 100644 --- a/pkg/cluster/slave_test.go +++ b/pkg/cluster/slave_test.go @@ -1,8 +1,12 @@ package cluster import ( + "bytes" + "encoding/json" + "errors" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/stretchr/testify/assert" @@ -442,124 +446,114 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) { } } -//func TestRemoteCallback(t *testing.T) { -// asserts := assert.New(t) -// -// // 回调成功 -// { -// clientMock := request.ClientMock{} -// mockResp, _ := json.Marshal(serializer.Response{Code: 0}) -// clientMock.On( -// "Request", -// "POST", -// "http://test/test/url", -// testMock.Anything, -// testMock.Anything, -// ).Return(&request.Response{ -// Err: nil, -// Response: &http.Response{ -// StatusCode: 200, -// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), -// }, -// }) -// request.GeneralClient = clientMock -// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ -// SourceName: "source", -// }) -// asserts.NoError(resp) -// clientMock.AssertExpectations(t) -// } -// -// // 服务端返回业务错误 -// { -// clientMock := request.ClientMock{} -// mockResp, _ := json.Marshal(serializer.Response{Code: 401}) -// clientMock.On( -// "Request", -// "POST", -// "http://test/test/url", -// testMock.Anything, -// testMock.Anything, -// ).Return(&request.Response{ -// Err: nil, -// Response: &http.Response{ -// StatusCode: 200, -// Body: ioutil.NopCloser(bytes.NewReader(mockResp)), -// }, -// }) -// request.GeneralClient = clientMock -// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ -// SourceName: "source", -// }) -// asserts.EqualValues(401, resp.(serializer.AppError).Code) -// clientMock.AssertExpectations(t) -// } -// -// // 无法解析回调响应 -// { -// clientMock := request.ClientMock{} -// clientMock.On( -// "Request", -// "POST", -// "http://test/test/url", -// testMock.Anything, -// testMock.Anything, -// ).Return(&request.Response{ -// Err: nil, -// Response: &http.Response{ -// StatusCode: 200, -// Body: ioutil.NopCloser(strings.NewReader("mockResp")), -// }, -// }) -// request.GeneralClient = clientMock -// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ -// SourceName: "source", -// }) -// asserts.Error(resp) -// clientMock.AssertExpectations(t) -// } -// -// // HTTP状态码非200 -// { -// clientMock := request.ClientMock{} -// clientMock.On( -// "Request", -// "POST", -// "http://test/test/url", -// testMock.Anything, -// testMock.Anything, -// ).Return(&request.Response{ -// Err: nil, -// Response: &http.Response{ -// StatusCode: 404, -// Body: ioutil.NopCloser(strings.NewReader("mockResp")), -// }, -// }) -// request.GeneralClient = clientMock -// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ -// SourceName: "source", -// }) -// asserts.Error(resp) -// clientMock.AssertExpectations(t) -// } -// -// // 无法发起回调 -// { -// clientMock := request.ClientMock{} -// clientMock.On( -// "Request", -// "POST", -// "http://test/test/url", -// testMock.Anything, -// testMock.Anything, -// ).Return(&request.Response{ -// Err: errors.New("error"), -// }) -// request.GeneralClient = clientMock -// resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{ -// SourceName: "source", -// }) -// asserts.Error(resp) -// clientMock.AssertExpectations(t) -// } -//} +func TestRemoteCallback(t *testing.T) { + asserts := assert.New(t) + + // 回调成功 + { + clientMock := controllermock.RequestMock{} + mockResp, _ := json.Marshal(serializer.Response{Code: 0}) + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(mockResp)), + }, + }) + request.GeneralClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) + asserts.NoError(resp) + clientMock.AssertExpectations(t) + } + + // 服务端返回业务错误 + { + clientMock := controllermock.RequestMock{} + mockResp, _ := json.Marshal(serializer.Response{Code: 401}) + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader(mockResp)), + }, + }) + request.GeneralClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) + asserts.EqualValues(401, resp.(serializer.AppError).Code) + clientMock.AssertExpectations(t) + } + + // 无法解析回调响应 + { + clientMock := controllermock.RequestMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("mockResp")), + }, + }) + request.GeneralClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } + + // HTTP状态码非200 + { + clientMock := controllermock.RequestMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 404, + Body: ioutil.NopCloser(strings.NewReader("mockResp")), + }, + }) + request.GeneralClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } + + // 无法发起回调 + { + clientMock := controllermock.RequestMock{} + clientMock.On( + "Request", + "POST", + "http://test/test/url", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: errors.New("error"), + }) + request.GeneralClient = clientMock + resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) + asserts.Error(resp) + clientMock.AssertExpectations(t) + } +} diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index aa95a7e..6d186ed 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -56,7 +56,11 @@ User = root Password = root Host = 127.0.0.1:3306 Name = v3 -TablePrefix = v3_` +TablePrefix = v3_ + +[OptionOverwrite] +key=value +` err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) defer func() { err = os.Remove("testConf.ini") }() if err != nil { @@ -65,6 +69,7 @@ TablePrefix = v3_` asserts.NotPanics(func() { Init("testConf.ini") }) + asserts.Equal(OptionOverwrite["key"], "value") } func TestMapSection(t *testing.T) { diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go index 2b085f1..01c450b 100644 --- a/pkg/mocks/mocks.go +++ b/pkg/mocks/mocks.go @@ -7,11 +7,9 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" testMock "github.com/stretchr/testify/mock" - "io" ) type NodePoolMock struct { @@ -151,11 +149,3 @@ func (t TaskPoolMock) Add(num int) { func (t TaskPoolMock) Submit(job task.Job) { t.Called(job) } - -type RequestMock struct { - testMock.Mock -} - -func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - return r.Called(method, target, body, opts).Get(0).(*request.Response) -} diff --git a/pkg/mocks/requestmock/request.go b/pkg/mocks/requestmock/request.go new file mode 100644 index 0000000..41581b2 --- /dev/null +++ b/pkg/mocks/requestmock/request.go @@ -0,0 +1,15 @@ +package controllermock + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/stretchr/testify/mock" + "io" +) + +type RequestMock struct { + mock.Mock +} + +func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { + return r.Called(method, target, body, opts).Get(0).(*request.Response) +}