You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cloudreve/middleware/auth_test.go

606 lines
16 KiB

package middleware
import (
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
// mock is the sqlmock controller shared by every test in this package; tests
// register their expected queries on it.
var mock sqlmock.Sqlmock

// TestMain initializes the mocked database: it opens a sqlmock-backed
// *sql.DB, wraps it with gorm as model.DB, and then runs the package's tests.
func TestMain(m *testing.M) {
	var db *sql.DB
	var err error
	// Plain assignment (not :=) so the package-level mock is populated.
	db, mock, err = sqlmock.New()
	if err != nil {
		panic("An error was not expected when opening a stub database connection")
	}
	defer db.Close()

	// Fail fast if gorm cannot wrap the stub connection instead of silently
	// continuing with a possibly unusable model.DB.
	model.DB, err = gorm.Open("mysql", db)
	if err != nil {
		panic("An error was not expected when wrapping the stub database with gorm: " + err.Error())
	}

	// Since Go 1.15, returning from TestMain makes the test binary exit with
	// m.Run's result, so the deferred Close still runs.
	m.Run()
}
// TestCurrentUser verifies that CurrentUser leaves "user" unset for an empty
// session and attaches the user loaded from the database once the session
// holds a user_id.
func TestCurrentUser(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	withSession := Session("233")

	// newCtx builds a GET /test context that has passed through the session middleware.
	newCtx := func() *gin.Context {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Request, _ = http.NewRequest("GET", "/test", nil)
		withSession(ctx)
		return ctx
	}

	// Empty session: nothing is attached.
	ctx := newCtx()
	CurrentUser()(ctx)
	got, _ := ctx.Get("user")
	a.Nil(got)

	// Session carries user_id 1: the user row is queried and attached.
	ctx = newCtx()
	util.SetSession(ctx, map[string]interface{}{"user_id": 1})
	userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}).
		AddRow(1, nil, "admin@cloudreve.org", "{}")
	mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
	CurrentUser()(ctx)
	got, _ = ctx.Get("user")
	a.NotNil(got)
	a.NoError(mock.ExpectationsWereMet())
}
// TestAuthRequired checks that AuthRequired aborts the request when nothing,
// or a value of the wrong type, is stored under "user", and lets the request
// continue when a *model.User is present.
func TestAuthRequired(t *testing.T) {
	asserts := assert.New(t)
	rec := httptest.NewRecorder()
	AuthRequiredFunc := AuthRequired()

	// Each case gets a fresh context: once a gin.Context has been aborted it
	// stays aborted, so sharing one would hide the outcome of later cases.
	newContext := func() *gin.Context {
		c, _ := gin.CreateTestContext(rec)
		c.Request, _ = http.NewRequest("GET", "/test", nil)
		return c
	}

	// Not logged in
	{
		c := newContext()
		AuthRequiredFunc(c)
		asserts.True(c.IsAborted())
	}

	// Wrong type stored under "user"
	{
		c := newContext()
		c.Set("user", 123)
		AuthRequiredFunc(c)
		asserts.True(c.IsAborted())
	}

	// Valid user
	{
		c := newContext()
		c.Set("user", &model.User{})
		AuthRequiredFunc(c)
		asserts.False(c.IsAborted())
	}
}
// TestSignRequired verifies that SignRequired rejects unsigned GET and PUT
// requests and accepts a PUT request signed with the same HMAC key.
func TestSignRequired(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
	middleware := SignRequired(authInstance)

	cases := []struct {
		method      string
		sign        bool
		wantAborted bool
	}{
		{"GET", false, true}, // unsigned GET: rejected
		{"PUT", false, true}, // unsigned PUT: rejected
		{"PUT", true, false}, // signed PUT: accepted
	}

	for _, tc := range cases {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Request, _ = http.NewRequest(tc.method, "/test", nil)
		if tc.sign {
			ctx.Request = auth.SignRequest(authInstance, ctx.Request, 0)
		}
		middleware(ctx)
		a.NotNil(ctx)
		a.Equal(tc.wantAborted, ctx.IsAborted())
	}
}
// TestWebDAVAuth exercises WebDAVAuth:
//   - OPTIONS requests skip authentication;
//   - a request without credentials gets a WWW-Authenticate challenge;
//   - an unknown user or a missing WebDAV account yields 401;
//   - a user whose group has WebDAV disabled yields 403;
//   - a valid login passes and stores the user in the context.
func TestWebDAVAuth(t *testing.T) {
	asserts := assert.New(t)
	rec := httptest.NewRecorder()
	AuthFunc := WebDAVAuth()

	// newBasicAuthContext builds a POST /test request carrying Basic
	// credentials for "who@cloudreve.org:admin".
	newBasicAuthContext := func() *gin.Context {
		c, _ := gin.CreateTestContext(rec)
		c.Request, _ = http.NewRequest("POST", "/test", nil)
		c.Request.Header = map[string][]string{
			"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
		}
		return c
	}

	// expectGroupedUser registers the users/groups/webdav query expectations
	// shared by the last two cases: user 1 in group 1 with an existing WebDAV
	// account, WebDAV enabled or not per the argument.
	expectGroupedUser := func(webDAVEnabled bool) {
		mock.ExpectQuery("SELECT(.+)users(.+)").
			WillReturnRows(
				sqlmock.NewRows(
					[]string{"id", "password", "email", "group_id", "options"}).
					AddRow(1,
						"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
						"who@cloudreve.org",
						1,
						"{}",
					),
			)
		mock.ExpectQuery("SELECT(.+)groups(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, webDAVEnabled))
		// WebDAV account lookup
		mock.ExpectQuery("SELECT(.+)webdav(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
	}

	// OPTIONS requests skip authentication
	{
		c, _ := gin.CreateTestContext(rec)
		c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
		AuthFunc(c)
		asserts.False(c.IsAborted())
	}

	// No credentials: the client is challenged for HTTP Basic Auth
	{
		c, _ := gin.CreateTestContext(rec)
		c.Request, _ = http.NewRequest("POST", "/test", nil)
		AuthFunc(c)
		asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
	}

	// Username does not exist
	{
		c := newBasicAuthContext()
		mock.ExpectQuery("SELECT(.+)users(.+)").
			WillReturnRows(
				sqlmock.NewRows([]string{"id", "password", "email"}),
			)
		AuthFunc(c)
		asserts.NoError(mock.ExpectationsWereMet())
		asserts.Equal(http.StatusUnauthorized, c.Writer.Status())
	}

	// Wrong password: no matching WebDAV account
	{
		c := newBasicAuthContext()
		mock.ExpectQuery("SELECT(.+)users(.+)").
			WillReturnRows(
				sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
			)
		// WebDAV account lookup finds nothing
		mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
		AuthFunc(c)
		asserts.NoError(mock.ExpectationsWereMet())
		asserts.Equal(http.StatusUnauthorized, c.Writer.Status())
	}

	// WebDAV not enabled for the user's group
	{
		c := newBasicAuthContext()
		expectGroupedUser(false)
		AuthFunc(c)
		asserts.NoError(mock.ExpectationsWereMet())
		asserts.Equal(http.StatusForbidden, c.Writer.Status())
	}

	// Successful login
	{
		c := newBasicAuthContext()
		expectGroupedUser(true)
		AuthFunc(c)
		asserts.NoError(mock.ExpectationsWereMet())
		asserts.Equal(http.StatusOK, c.Writer.Status())
		_, ok := c.Get("user")
		asserts.True(ok)
	}
}
// TestUseUploadSession checks that UseUploadSession("local") aborts when the
// sessionID route parameter is missing. It also checks that the middleware lets
// the request through when the upload session is cached and the user, group and
// policy queries succeed.
func TestUseUploadSession(t *testing.T) {
	asserts := assert.New(t)
	rec := httptest.NewRecorder()
	AuthFunc := UseUploadSession("local")

	// Empty sessionID
	{
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{}
		c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
		authInstance := auth.HMACAuth{SecretKey: []byte("123")}
		c.Request = auth.SignRequest(authInstance, c.Request, 0)
		AuthFunc(c)
		asserts.True(c.IsAborted())
	}

	// Success
	{
		cache.Set(
			filesystem.UploadSessionCachePrefix+"testCallBackRemote",
			serializer.UploadSession{
				UID:         1,
				VirtualPath: "/",
				Policy:      model.Policy{Type: "local"},
			},
			0,
		)
		// Drop the cached "policy_1" entry, presumably so the policy is read
		// through the mocked queries below — confirm.
		cache.Deletes([]string{"1"}, "policy_")
		mock.ExpectQuery("SELECT(.+)users(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
		mock.ExpectQuery("SELECT(.+)groups(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]"))
		mock.ExpectQuery("SELECT(.+)policies(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{
			{Key: "sessionID", Value: "testCallBackRemote"},
		}
		c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
		authInstance := auth.HMACAuth{SecretKey: []byte("123")}
		c.Request = auth.SignRequest(authInstance, c.Request, 0)
		AuthFunc(c)
		asserts.NoError(mock.ExpectationsWereMet())
		asserts.False(c.IsAborted())
	}
}
// TestUploadCallbackCheck covers three failure paths of uploadCallbackCheck:
//   - a missing upload session;
//   - a session whose policy type differs from the expected one;
//   - a session whose owner cannot be found. After this case the cached
//     session is expected to be gone.
//
// NOTE(review): testify's Contains(s, contains) asserts that s contains
// contains. The calls below therefore check that the literal contains
// res.Msg, and an empty res.Msg passes trivially. Confirm whether this
// argument order is intentional before tightening these assertions.
func TestUploadCallbackCheck(t *testing.T) {
	a := assert.New(t)
	rec := httptest.NewRecorder()

	// Upload session does not exist
	{
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{
			{Key: "sessionID", Value: "testSessionNotExist"},
		}
		res := uploadCallbackCheck(c, "local")
		a.Contains("上传会话不存在或已过期", res.Msg)
	}

	// Upload policy type does not match
	{
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{
			{Key: "sessionID", Value: "testPolicyNotMatch"},
		}
		cache.Set(
			filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
			serializer.UploadSession{
				UID:         1,
				VirtualPath: "/",
				Policy:      model.Policy{Type: "remote"},
			},
			0,
		)
		res := uploadCallbackCheck(c, "local")
		a.Contains("Policy not supported", res.Msg)
	}

	// User does not exist
	{
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{
			{Key: "sessionID", Value: "testUserNotExist"},
		}
		cache.Set(
			filesystem.UploadSessionCachePrefix+"testUserNotExist",
			serializer.UploadSession{
				UID:         313,
				VirtualPath: "/",
				Policy:      model.Policy{Type: "remote"},
			},
			0,
		)
		mock.ExpectQuery("SELECT(.+)users(.+)").
			WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
		res := uploadCallbackCheck(c, "remote")
		a.Contains("找不到用户", res.Msg)
		a.NoError(mock.ExpectationsWereMet())
		_, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
		a.False(ok)
	}
}
// TestRemoteCallbackAuth verifies that RemoteCallbackAuth accepts a callback
// signed with the upload policy's secret key and aborts an unsigned one.
func TestRemoteCallbackAuth(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	middleware := RemoteCallbackAuth()

	cases := []struct {
		name        string
		sign        bool
		wantAborted bool
	}{
		{"signed with policy secret", true, false},
		{"unsigned request", false, true},
	}

	for _, tc := range cases {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
			UID:         1,
			VirtualPath: "/",
			Policy:      model.Policy{SecretKey: "123"},
		})
		ctx.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
		if tc.sign {
			signer := auth.HMACAuth{SecretKey: []byte("123")}
			auth.SignRequest(signer, ctx.Request, 0)
		}
		middleware(ctx)
		a.Equal(tc.wantAborted, ctx.IsAborted(), tc.name)
	}
}
// TestQiniuCallbackAuth verifies that QiniuCallbackAuth accepts a QBox token
// produced with the policy's keys and rejects one produced with a different
// secret key.
func TestQiniuCallbackAuth(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	middleware := QiniuCallbackAuth()

	// newSignedCtx builds a callback context whose policy uses
	// AccessKey/SecretKey "123". The request is signed with access key "123"
	// and the given secret key.
	newSignedCtx := func(signingSecret string) *gin.Context {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
			UID:         1,
			VirtualPath: "/",
			Policy: model.Policy{
				SecretKey: "123",
				AccessKey: "123",
			},
		})
		ctx.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
		token, err := qbox.NewMac("123", signingSecret).SignRequest(ctx.Request)
		a.NoError(err)
		ctx.Request.Header["Authorization"] = []string{"QBox " + token}
		return ctx
	}

	// Matching secret key: accepted.
	accepted := newSignedCtx("123")
	middleware(accepted)
	a.NoError(mock.ExpectationsWereMet())
	a.False(accepted.IsAborted())

	// Different secret key: rejected.
	rejected := newSignedCtx("1213")
	middleware(rejected)
	a.True(rejected.IsAborted())
}
// TestOSSCallbackAuth verifies that OSSCallbackAuth rejects a callback that
// carries a QBox-style token and accepts a pre-computed OSS callback with its
// matching signature and public-key URL header.
func TestOSSCallbackAuth(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	middleware := OSSCallbackAuth()

	// newCtx returns a context holding an upload session whose policy uses
	// AccessKey/SecretKey "123".
	newCtx := func() *gin.Context {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
			UID:         1,
			VirtualPath: "/",
			Policy: model.Policy{
				SecretKey: "123",
				AccessKey: "123",
			},
		})
		return ctx
	}

	// Signature verification fails.
	rejected := newCtx()
	rejected.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
	token, err := qbox.NewMac("123", "123").SignRequest(rejected.Request)
	a.NoError(err)
	rejected.Request.Header["Authorization"] = []string{"QBox " + token}
	middleware(rejected)
	a.True(rejected.IsAborted())

	// Valid OSS callback.
	// NOTE(review): X-Oss-Pub-Key-Url decodes to an alicdn.com URL; if the
	// middleware downloads that key, this case needs network access — confirm.
	accepted := newCtx()
	payload := `{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`
	accepted.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(payload)))
	accepted.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
	accepted.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
	middleware(accepted)
	a.False(accepted.IsAborted())
}
// fakeRead is an io.Reader stand-in whose Read always fails. It simulates a
// request body that cannot be read.
type fakeRead string

// Read ignores the buffer and reports zero bytes read along with a fresh
// "error" error.
func (fakeRead) Read([]byte) (n int, err error) {
	err = errors.New("error")
	return n, err
}
// TestUpyunCallbackAuth verifies that UpyunCallbackAuth aborts in three cases
// and accepts one:
//   - aborts when the body cannot be read;
//   - aborts when Content-Md5 does not match the body;
//   - aborts when the signature is missing;
//   - accepts a request with a valid MD5 and signature.
func TestUpyunCallbackAuth(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	middleware := UpyunCallbackAuth()
	const callbackURL = "/api/v3/callback/upyun/testCallBackUpyun"

	// newCtx returns a context holding an upload session whose policy uses
	// AccessKey/SecretKey "123".
	newCtx := func() *gin.Context {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
			UID:         1,
			VirtualPath: "/",
			Policy: model.Policy{
				SecretKey: "123",
				AccessKey: "123",
			},
		})
		return ctx
	}

	// abortedWithBodyOne sends a callback whose body is "1" plus the given
	// headers and reports whether the middleware aborted it.
	abortedWithBodyOne := func(headers map[string]string) bool {
		ctx := newCtx()
		ctx.Request, _ = http.NewRequest("POST", callbackURL, ioutil.NopCloser(strings.NewReader("1")))
		for name, value := range headers {
			ctx.Request.Header[name] = []string{value}
		}
		middleware(ctx)
		return ctx.IsAborted()
	}

	// Request body cannot be read.
	unreadable := newCtx()
	unreadable.Request, _ = http.NewRequest("POST", callbackURL, ioutil.NopCloser(fakeRead("")))
	middleware(unreadable)
	a.True(unreadable.IsAborted())

	// Body MD5 does not match the header.
	a.True(abortedWithBodyOne(map[string]string{
		"Content-Md5": "123",
	}))

	// MD5 matches (md5("1")) but the signature is missing.
	a.True(abortedWithBodyOne(map[string]string{
		"Content-Md5": "c4ca4238a0b923820dcc509a6f75849b",
	}))

	// MD5 and signature both valid.
	a.False(abortedWithBodyOne(map[string]string{
		"Content-Md5":   "c4ca4238a0b923820dcc509a6f75849b",
		"Authorization": "UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM=",
	}))
}
// TestOneDriveCallbackAuth checks two things about OneDriveCallbackAuth:
// - it publishes a message on the MQ topic named after the sessionID route
//   parameter within 500ms;
// - it does not abort the request.
func TestOneDriveCallbackAuth(t *testing.T) {
	asserts := assert.New(t)
	rec := httptest.NewRecorder()
	AuthFunc := OneDriveCallbackAuth()

	// Success
	{
		c, _ := gin.CreateTestContext(rec)
		c.Params = []gin.Param{
			{Key: "sessionID", Value: "TestOneDriveCallbackAuth"},
		}
		c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
			UID:         1,
			VirtualPath: "/",
			Policy: model.Policy{
				SecretKey: "123",
				AccessKey: "123",
			},
		})
		c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
		// Subscribe before invoking the middleware so the message cannot be missed.
		res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
		AuthFunc(c)
		select {
		case <-res:
		case <-time.After(500 * time.Millisecond):
			asserts.Fail("mq message should be published")
		}
		asserts.False(c.IsAborted())
	}
}
// TestIsAdmin verifies that IsAdmin aborts for an ordinary user. It lets the
// request continue for a user in group 1 and for user ID 1 even outside
// group 1.
func TestIsAdmin(t *testing.T) {
	a := assert.New(t)
	recorder := httptest.NewRecorder()
	check := IsAdmin()

	adminGroupUser := &model.User{}
	adminGroupUser.Group.ID = 1

	initialUser := &model.User{}
	initialUser.Group.ID = 2
	initialUser.ID = 1

	cases := []struct {
		user        *model.User
		wantAborted bool
	}{
		{&model.User{}, true},  // not an administrator
		{adminGroupUser, false}, // administrator group
		{initialUser, false},    // initial user outside the admin group
	}

	for _, tc := range cases {
		ctx, _ := gin.CreateTestContext(recorder)
		ctx.Set("user", tc.user)
		check(ctx)
		a.Equal(tc.wantAborted, ctx.IsAborted())
	}
}