diff --git a/middleware/wopi_test.go b/middleware/wopi_test.go
new file mode 100644
index 0000000..c6ca327
--- /dev/null
+++ b/middleware/wopi_test.go
@@ -0,0 +1,112 @@
+package middleware
+
+import (
+ "errors"
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
+ "github.com/cloudreve/Cloudreve/v3/pkg/wopi"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestWopiWriteAccess(t *testing.T) {
+ asserts := assert.New(t)
+ rec := httptest.NewRecorder()
+ testFunc := WopiWriteAccess()
+
+ // deny preview only session
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
+ testFunc(c)
+ asserts.True(c.IsAborted())
+ }
+
+ // pass
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
+ testFunc(c)
+ asserts.False(c.IsAborted())
+ }
+}
+
+func TestWopiAccessValidation(t *testing.T) {
+ asserts := assert.New(t)
+ rec := httptest.NewRecorder()
+ mockWopi := &wopimock.WopiClientMock{}
+ mockCache := cache.NewMemoStore()
+ testFunc := WopiAccessValidation(mockWopi, mockCache)
+
+ // malformed access token
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.AddParam(wopi.AccessTokenQuery, "000")
+ testFunc(c)
+ asserts.True(c.IsAborted())
+ }
+
+ // session key not exist
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
+ query := c.Request.URL.Query()
+ query.Set(wopi.AccessTokenQuery, "sessionID.key")
+ c.Request.URL.RawQuery = query.Encode()
+ testFunc(c)
+ asserts.True(c.IsAborted())
+ }
+
+ // user key not exist
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
+ query := c.Request.URL.Query()
+ query.Set(wopi.AccessTokenQuery, "sessionID.key")
+ c.Request.URL.RawQuery = query.Encode()
+ mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
+ mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
+ testFunc(c)
+ asserts.True(c.IsAborted())
+ asserts.NoError(mock.ExpectationsWereMet())
+ }
+
+ // file not found
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
+ query := c.Request.URL.Query()
+ query.Set(wopi.AccessTokenQuery, "sessionID.key")
+ c.Request.URL.RawQuery = query.Encode()
+ mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
+ mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
+ c.Set("object_id", uint(0))
+ testFunc(c)
+ asserts.True(c.IsAborted())
+ asserts.NoError(mock.ExpectationsWereMet())
+ }
+
+ // all pass
+ {
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
+ query := c.Request.URL.Query()
+ query.Set(wopi.AccessTokenQuery, "sessionID.key")
+ c.Request.URL.RawQuery = query.Encode()
+ mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
+ mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
+ c.Set("object_id", uint(1))
+ testFunc(c)
+ asserts.False(c.IsAborted())
+ asserts.NoError(mock.ExpectationsWereMet())
+ asserts.NotPanics(func() {
+ c.MustGet(WopiSessionCtx)
+ })
+ asserts.NotPanics(func() {
+ c.MustGet("user")
+ })
+ }
+}
diff --git a/pkg/mocks/cachemock/mock.go b/pkg/mocks/cachemock/mock.go
new file mode 100644
index 0000000..98fe78c
--- /dev/null
+++ b/pkg/mocks/cachemock/mock.go
@@ -0,0 +1,29 @@
+package cachemock
+
+import "github.com/stretchr/testify/mock"
+
+type CacheClientMock struct {
+ mock.Mock
+}
+
+func (c CacheClientMock) Set(key string, value interface{}, ttl int) error {
+ return c.Called(key, value, ttl).Error(0)
+}
+
+func (c CacheClientMock) Get(key string) (interface{}, bool) {
+ args := c.Called(key)
+ return args.Get(0), args.Bool(1)
+}
+
+func (c CacheClientMock) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
+ args := c.Called(keys, prefix)
+ return args.Get(0).(map[string]interface{}), args.Get(1).([]string)
+}
+
+func (c CacheClientMock) Sets(values map[string]interface{}, prefix string) error {
+ return c.Called(values).Error(0)
+}
+
+func (c CacheClientMock) Delete(keys []string, prefix string) error {
+ return c.Called(keys, prefix).Error(0)
+}
diff --git a/pkg/mocks/wopimock/mock.go b/pkg/mocks/wopimock/mock.go
new file mode 100644
index 0000000..a11eb4f
--- /dev/null
+++ b/pkg/mocks/wopimock/mock.go
@@ -0,0 +1,21 @@
+package wopimock
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/wopi"
+ "github.com/stretchr/testify/mock"
+)
+
+type WopiClientMock struct {
+ mock.Mock
+}
+
+func (w *WopiClientMock) NewSession(user *model.User, file *model.File, action wopi.ActonType) (*wopi.Session, error) {
+ args := w.Called(user, file, action)
+ return args.Get(0).(*wopi.Session), args.Error(1)
+}
+
+func (w *WopiClientMock) AvailableExts() []string {
+ args := w.Called()
+ return args.Get(0).([]string)
+}
diff --git a/pkg/wopi/discovery.go b/pkg/wopi/discovery.go
index d4f480f..a9b6944 100644
--- a/pkg/wopi/discovery.go
+++ b/pkg/wopi/discovery.go
@@ -3,7 +3,6 @@ package wopi
import (
"encoding/xml"
"fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strings"
@@ -62,7 +61,7 @@ func (c *client) refreshDiscovery() error {
c.mu.Lock()
defer c.mu.Unlock()
- cached, exist := cache.Get(DiscoverResponseCacheKey)
+ cached, exist := c.cache.Get(DiscoverResponseCacheKey)
if exist {
cachedDiscovery := cached.(WopiDiscovery)
c.discovery = &cachedDiscovery
@@ -70,14 +69,14 @@ func (c *client) refreshDiscovery() error {
res, err := c.http.Request("GET", c.config.discoveryEndpoint.String(), nil).
CheckHTTPResponse(http.StatusOK).GetResponse()
if err != nil {
- return fmt.Errorf("failed to request discovery endpoint: %s", err)
+ return fmt.Errorf("failed to request discovery endpoint: %w", err)
}
if err := xml.Unmarshal([]byte(res), &c.discovery); err != nil {
- return fmt.Errorf("failed to parse response discovery endpoint: %s", err)
+ return fmt.Errorf("failed to parse response discovery endpoint: %w", err)
}
- if err := cache.Set(DiscoverResponseCacheKey, *c.discovery, DiscoverRefreshDuration); err != nil {
+ if err := c.cache.Set(DiscoverResponseCacheKey, *c.discovery, DiscoverRefreshDuration); err != nil {
return err
}
}
diff --git a/pkg/wopi/discovery_test.go b/pkg/wopi/discovery_test.go
index 4e9bd98..8092384 100644
--- a/pkg/wopi/discovery_test.go
+++ b/pkg/wopi/discovery_test.go
@@ -1,25 +1,129 @@
package wopi
import (
- model "github.com/cloudreve/Cloudreve/v3/models"
+ "errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/stretchr/testify/assert"
+ testMock "github.com/stretchr/testify/mock"
+ "io"
+ "net/http"
"net/url"
+ "strings"
"testing"
)
-func TestDiscovery(t *testing.T) {
+func TestClient_AvailableExts(t *testing.T) {
a := assert.New(t)
endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
client := &client{
- cache: cache.Store,
- http: request.NewClient(),
+ cache: cache.NewMemoStore(),
config: config{
discoveryEndpoint: endpoint,
},
}
- a.NoError(client.refreshDiscovery())
- client.NewSession(nil, &model.File{Name: "123.pptx"}, ActionPreview)
+ // Discovery failed
+ {
+ expectedErr := errors.New("error")
+ mockHttp := &requestmock.RequestMock{}
+ client.http = mockHttp
+ mockHttp.On(
+ "Request",
+ "GET",
+ endpoint.String(),
+ testMock.Anything,
+ testMock.Anything,
+ ).Return(&request.Response{
+ Err: expectedErr,
+ })
+ res := client.AvailableExts()
+ a.Empty(res)
+ mockHttp.AssertExpectations(t)
+ }
+
+ // pass
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {
+ string(ActionPreviewFallback): Action{},
+ },
+ ".ppt": {},
+ ".xls": {
+ "not_supported": Action{},
+ },
+ }
+ res := client.AvailableExts()
+ a.Len(res, 1)
+ a.Equal("doc", res[0])
+ }
+}
+
+func TestClient_RefreshDiscovery(t *testing.T) {
+ a := assert.New(t)
+ endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
+ client := &client{
+ cache: cache.NewMemoStore(),
+ config: config{
+ discoveryEndpoint: endpoint,
+ },
+ }
+
+ // cache hit
+ {
+ client.cache.Set(DiscoverResponseCacheKey, WopiDiscovery{Text: "test"}, 0)
+ a.NoError(client.checkDiscovery())
+ a.Equal("test", client.discovery.Text)
+ client.discovery = &WopiDiscovery{}
+ client.cache.Delete([]string{DiscoverResponseCacheKey}, "")
+ }
+
+ // malformed xml
+ {
+ mockHttp := &requestmock.RequestMock{}
+ client.http = mockHttp
+ mockHttp.On(
+ "Request",
+ "GET",
+ endpoint.String(),
+ testMock.Anything,
+ testMock.Anything,
+ ).Return(&request.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader(`{"code":203}`)),
+ },
+ })
+ res := client.refreshDiscovery()
+ a.ErrorContains(res, "failed to parse")
+ mockHttp.AssertExpectations(t)
+ }
+
+ // all pass
+ {
+ testResponse := `
+`
+ mockHttp := &requestmock.RequestMock{}
+ client.http = mockHttp
+ mockHttp.On(
+ "Request",
+ "GET",
+ endpoint.String(),
+ testMock.Anything,
+ testMock.Anything,
+ ).Return(&request.Response{
+ Response: &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader(testResponse)),
+ },
+ })
+ res := client.refreshDiscovery()
+ a.NoError(res, res)
+ a.NotEmpty(client.actions[".docx"])
+ a.NotEmpty(client.actions[".docx"][string(ActionPreview)])
+ a.NotEmpty(client.actions[".docx"][string(ActionEdit)])
+ mockHttp.AssertExpectations(t)
+ }
}
diff --git a/pkg/wopi/wopi.go b/pkg/wopi/wopi.go
index bac734d..bf92985 100644
--- a/pkg/wopi/wopi.go
+++ b/pkg/wopi/wopi.go
@@ -165,9 +165,9 @@ func (c *client) NewSession(user *model.User, file *model.File, action ActonType
UserID: user.ID,
Action: action,
}
- err = cache.Set(SessionCachePrefix+sessionID.String(), *session, ttl)
+ err = c.cache.Set(SessionCachePrefix+sessionID.String(), *session, ttl)
if err != nil {
- return nil, fmt.Errorf("failed to create document session: %s", err)
+ return nil, fmt.Errorf("failed to create document session: %w", err)
}
sessionRes := &Session{
diff --git a/pkg/wopi/wopi_test.go b/pkg/wopi/wopi_test.go
new file mode 100644
index 0000000..e0de5f9
--- /dev/null
+++ b/pkg/wopi/wopi_test.go
@@ -0,0 +1,184 @@
+package wopi
+
+import (
+ "database/sql"
+ "errors"
+ "github.com/DATA-DOG/go-sqlmock"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/mocks/cachemock"
+ "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
+ "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "github.com/jinzhu/gorm"
+ "github.com/stretchr/testify/assert"
+ testMock "github.com/stretchr/testify/mock"
+ "net/url"
+ "testing"
+)
+
+var mock sqlmock.Sqlmock
+
+// TestMain 初始化数据库Mock
+func TestMain(m *testing.M) {
+ var db *sql.DB
+ var err error
+ db, mock, err = sqlmock.New()
+ if err != nil {
+ panic("An error was not expected when opening a stub database connection")
+ }
+ model.DB, _ = gorm.Open("mysql", db)
+ defer db.Close()
+ m.Run()
+}
+
+func TestNewSession(t *testing.T) {
+ a := assert.New(t)
+ endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
+ client := &client{
+ cache: cache.NewMemoStore(),
+ config: config{
+ discoveryEndpoint: endpoint,
+ },
+ }
+
+ // Discovery failed
+ {
+ expectedErr := errors.New("error")
+ mockHttp := &requestmock.RequestMock{}
+ client.http = mockHttp
+ mockHttp.On(
+ "Request",
+ "GET",
+ endpoint.String(),
+ testMock.Anything,
+ testMock.Anything,
+ ).Return(&request.Response{
+ Err: expectedErr,
+ })
+ res, err := client.NewSession(&model.User{}, &model.File{}, ActionPreview)
+ a.Nil(res)
+ a.ErrorIs(err, expectedErr)
+ mockHttp.AssertExpectations(t)
+ }
+
+ // not supported ext
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = make(map[string]map[string]Action)
+ res, err := client.NewSession(&model.User{}, &model.File{}, ActionPreview)
+ a.Nil(res)
+ a.ErrorIs(err, ErrActionNotSupported)
+ }
+
+ // preferred action not supported
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {},
+ }
+ res, err := client.NewSession(&model.User{}, &model.File{Name: "1.doc"}, ActionPreview)
+ a.Nil(res)
+ a.ErrorIs(err, ErrActionNotSupported)
+ }
+
+ // src url cannot be parsed
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {
+ string(ActionPreviewFallback): Action{
+ Urlsrc: string([]byte{0x7f}),
+ },
+ },
+ }
+ res, err := client.NewSession(&model.User{}, &model.File{Name: "1.doc"}, ActionEdit)
+ a.Nil(res)
+ a.ErrorContains(err, "invalid control character in URL")
+ }
+
+ // all pass - default placeholder
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {
+ string(ActionPreviewFallback): Action{
+ Urlsrc: "https://doc.com/doc",
+ },
+ },
+ }
+ res, err := client.NewSession(&model.User{}, &model.File{Name: "1.doc"}, ActionEdit)
+ a.NotNil(res)
+ a.NoError(err)
+ resUrl := res.ActionURL.String()
+ a.Contains(resUrl, wopiSrcParamDefault)
+ }
+
+ // all pass - with placeholders
+ {
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {
+ string(ActionPreviewFallback): Action{
+ Urlsrc: "https://doc.com/doc?origin=preserved&",
+ },
+ },
+ }
+ res, err := client.NewSession(&model.User{}, &model.File{Name: "1.doc"}, ActionEdit)
+ a.NotNil(res)
+ a.NoError(err)
+ resUrl := res.ActionURL.String()
+ a.Contains(resUrl, "origin=preserved")
+ a.Contains(resUrl, "dc=lng")
+ a.Contains(resUrl, "src=")
+ a.NotContains(resUrl, "notsuported")
+ }
+
+ // cache operation failed
+ {
+ mockCache := &cachemock.CacheClientMock{}
+ expectedErr := errors.New("error")
+ client.cache = mockCache
+ client.discovery = &WopiDiscovery{}
+ client.actions = map[string]map[string]Action{
+ ".doc": {
+ string(ActionPreviewFallback): Action{
+ Urlsrc: "https://doc.com/doc",
+ },
+ },
+ }
+ mockCache.On("Set", testMock.Anything, testMock.Anything, testMock.Anything).Return(expectedErr)
+ res, err := client.NewSession(&model.User{}, &model.File{Name: "1.doc"}, ActionEdit)
+ a.Nil(res)
+ a.ErrorIs(err, expectedErr)
+ }
+}
+
+func TestInit(t *testing.T) {
+ a := assert.New(t)
+
+ // not enabled
+ {
+ a.Nil(Default)
+ Default = &client{}
+ Init()
+ a.Nil(Default)
+ }
+
+ // throw error
+ {
+ a.Nil(Default)
+ cache.Set("setting_wopi_enabled", "1", 0)
+ cache.Set("setting_wopi_endpoint", string([]byte{0x7f}), 0)
+ Init()
+ a.Nil(Default)
+ }
+
+ // all pass
+ {
+ a.Nil(Default)
+ cache.Set("setting_wopi_enabled", "1", 0)
+ cache.Set("setting_wopi_endpoint", "", 0)
+ Init()
+ a.NotNil(Default)
+ }
+}