Test: auth middleware for WebDAV

pull/247/head
HFO4 5 years ago
parent cf90ab5a9a
commit fd7b6e33c8

@ -52,6 +52,7 @@ func AuthRequired() gin.HandlerFunc {
} }
// WebDAVAuth 验证WebDAV登录及权限 // WebDAVAuth 验证WebDAV登录及权限
// TODO 测试
func WebDAVAuth() gin.HandlerFunc { func WebDAVAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// OPTIONS 请求不需要鉴权否则Windows10下无法保存文档 // OPTIONS 请求不需要鉴权否则Windows10下无法保存文档

@ -90,3 +90,107 @@ func TestSignRequired(t *testing.T) {
SignRequiredFunc(c) SignRequiredFunc(c)
asserts.NotNil(c) asserts.NotNil(c)
} }
func TestWebDAVAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := WebDAVAuth()
// options请求跳过验证
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
AuthFunc(c)
}
// 请求HTTP Basic Auth
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
AuthFunc(c)
asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
}
// 用户名不存在
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email"}),
)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
// 密码错误
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
//未启用 WebDAV
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusForbidden)
}
//正常
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), 200)
_, ok := c.Get("user")
asserts.True(ok)
}
}

@ -25,14 +25,18 @@ func Init() {
// Driver 键值缓存存储容器 // Driver 键值缓存存储容器
type Driver interface { type Driver interface {
// 设置值 // 设置值ttl为过期时间单位为秒
Set(key string, value interface{}, ttl int) error Set(key string, value interface{}, ttl int) error
// 取值
// 取值,并返回是否成功
Get(key string) (interface{}, bool) Get(key string) (interface{}, bool)
// 批量取值返回成功取值的map即不存在的值 // 批量取值返回成功取值的map即不存在的值
Gets(keys []string, prefix string) (map[string]interface{}, []string) Gets(keys []string, prefix string) (map[string]interface{}, []string)
// 批量设置值
// 批量设置值所有的key都会加上prefix前缀
Sets(values map[string]interface{}, prefix string) error Sets(values map[string]interface{}, prefix string) error
// 删除值 // 删除值
Delete(keys []string, prefix string) error Delete(keys []string, prefix string) error
} }

@ -604,6 +604,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil
mw := multistatusWriter{w: w} mw := multistatusWriter{w: w}
walkFn := func(reqPath string, info FileInfo, err error) error { walkFn := func(reqPath string, info FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
} }
@ -626,7 +627,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil
if err != nil { if err != nil {
return err return err
} }
href := path.Join(h.Prefix, strconv.FormatUint(uint64(fs.User.ID), 10), reqPath) href := path.Join(h.Prefix, reqPath)
if href != "/" && info.IsDir() { if href != "/" && info.IsDir() {
href += "/" href += "/"
} }

Loading…
Cancel
Save