Test: signRequired middleware

pull/247/head
HFO4 5 years ago
parent 297b507ca7
commit 9f26c0c8ab

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/auth"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -76,3 +77,16 @@ func TestAuthRequired(t *testing.T) {
AuthRequiredFunc(c) AuthRequiredFunc(c)
asserts.NotNil(c) asserts.NotNil(c)
} }
func TestSignRequired(t *testing.T) {
asserts := assert.New(t)
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired()
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
}

@ -54,7 +54,6 @@ func (folder *Folder) GetChildFiles() ([]File, error) {
// GetFilesByIDs 根据文件ID批量获取文件, // GetFilesByIDs 根据文件ID批量获取文件,
// UID为0表示忽略用户只根据文件ID检索 // UID为0表示忽略用户只根据文件ID检索
// TODO 测试
func GetFilesByIDs(ids []uint, uid uint) ([]File, error) { func GetFilesByIDs(ids []uint, uid uint) ([]File, error) {
var files []File var files []File
var result *gorm.DB var result *gorm.DB

@ -106,6 +106,17 @@ func TestGetFilesByIDs(t *testing.T) {
asserts.NoError(err) asserts.NoError(err)
asserts.Len(folders, 1) asserts.Len(folders, 1)
} }
// 忽略UID查找
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(folders, 1)
}
} }
func TestGetChildFilesOfFolders(t *testing.T) { func TestGetChildFilesOfFolders(t *testing.T) {

@ -22,17 +22,18 @@ type Auth interface {
Check(body string, sign string) error Check(body string, sign string) error
} }
// SignURI 对URI进行签名 // SignURI 对URI进行签名,签名只针对Path部分query部分不做验证
// TODO 测试 // TODO 测试
func SignURI(uri string, expires int64) (*url.URL, error) { func SignURI(uri string, expires int64) (*url.URL, error) {
// 生成签名
sign := General.Sign(uri, expires)
// 将签名加到URI中
base, err := url.Parse(uri) base, err := url.Parse(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 生成签名
sign := General.Sign(base.Path, expires)
// 将签名加到URI中
queries := base.Query() queries := base.Query()
queries.Set("sign", sign) queries.Set("sign", sign)
base.RawQuery = queries.Encode() base.RawQuery = queries.Encode()
@ -47,9 +48,8 @@ func CheckURI(url *url.URL) error {
sign := queries.Get("sign") sign := queries.Get("sign")
queries.Del("sign") queries.Del("sign")
url.RawQuery = queries.Encode() url.RawQuery = queries.Encode()
requestURI := url.RequestURI()
return General.Check(requestURI, sign) return General.Check(url.Path, sign)
} }
// Init 初始化通用鉴权器 // Init 初始化通用鉴权器

@ -0,0 +1,48 @@
package auth
import (
"github.com/HFO4/cloudreve/pkg/util"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestSignURI(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 成功
{
sign, err := SignURI("/api/v3/something?id=1", 0)
asserts.NoError(err)
queries := sign.Query()
asserts.Equal("1", queries.Get("id"))
asserts.NotEmpty(queries.Get("sign"))
}
// URI解码失败
{
sign, err := SignURI("://dg.;'f]gh./'", 0)
asserts.Error(err)
asserts.Nil(sign)
}
}
func TestCheckURI(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 成功
{
sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()+10)
asserts.NoError(err)
asserts.NoError(CheckURI(sign))
}
// 过期
{
sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()-1)
asserts.NoError(err)
asserts.Error(CheckURI(sign))
}
}

@ -6,15 +6,13 @@ import (
) )
// Store 缓存存储器 // Store 缓存存储器
var Store Driver var Store Driver = NewMemoStore()
// Init 初始化缓存 // Init 初始化缓存
func Init() { func Init() {
//Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") //Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0")
//return //return
if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode { if conf.RedisConfig.Server != "" && gin.Mode() == gin.TestMode {
Store = NewMemoStore()
} else {
Store = NewRedisStore( Store = NewRedisStore(
10, 10,
"tcp", "tcp",

Loading…
Cancel
Save