|
|
|
@ -2,8 +2,11 @@ package middleware
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"github.com/HFO4/cloudreve/pkg/conf"
|
|
|
|
|
"github.com/HFO4/cloudreve/pkg/util"
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
"testing"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -28,3 +31,41 @@ func TestSession(t *testing.T) {
|
|
|
|
|
func emptyFunc() gin.HandlerFunc {
|
|
|
|
|
return func(c *gin.Context) {}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCSRFInit(t *testing.T) {
|
|
|
|
|
asserts := assert.New(t)
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
sessionFunc := Session("233")
|
|
|
|
|
{
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
|
|
|
|
sessionFunc(c)
|
|
|
|
|
CSRFInit()(c)
|
|
|
|
|
asserts.True(util.GetSession(c, "CSRF").(bool))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCSRFCheck(t *testing.T) {
|
|
|
|
|
asserts := assert.New(t)
|
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
|
sessionFunc := Session("233")
|
|
|
|
|
|
|
|
|
|
// 通过检查
|
|
|
|
|
{
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
|
|
|
|
sessionFunc(c)
|
|
|
|
|
CSRFInit()(c)
|
|
|
|
|
CSRFCheck()(c)
|
|
|
|
|
asserts.False(c.IsAborted())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 未通过检查
|
|
|
|
|
{
|
|
|
|
|
c, _ := gin.CreateTestContext(rec)
|
|
|
|
|
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
|
|
|
|
sessionFunc(c)
|
|
|
|
|
CSRFCheck()(c)
|
|
|
|
|
asserts.True(c.IsAborted())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|