From 32c023210554be518528b5296a36f6d1b5a1d197 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Tue, 17 Mar 2020 15:57:38 +0800 Subject: [PATCH] Fix: file preview URL in share page should not be accessed directly --- middleware/session.go | 22 ++++++++++++++++++++ middleware/session_test.go | 41 ++++++++++++++++++++++++++++++++++++++ routers/router.go | 3 ++- 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/middleware/session.go b/middleware/session.go index 3a80e8b..06b56e5 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -2,6 +2,7 @@ package middleware import ( "github.com/HFO4/cloudreve/pkg/conf" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/memstore" @@ -32,3 +33,24 @@ func Session(secret string) gin.HandlerFunc { Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) return sessions.Sessions("cloudreve-session", Store) } + +// CSRFInit 初始化CSRF标记 +func CSRFInit() gin.HandlerFunc { + return func(c *gin.Context) { + util.SetSession(c, map[string]interface{}{"CSRF": true}) + c.Next() + } +} + +// CSRFCheck 检查CSRF标记 +func CSRFCheck() gin.HandlerFunc { + return func(c *gin.Context) { + if check, ok := util.GetSession(c, "CSRF").(bool); ok && check { + c.Next() + return + } + + c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil)) + c.Abort() + } +} diff --git a/middleware/session_test.go b/middleware/session_test.go index af0eb96..cc4c23f 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -2,8 +2,11 @@ package middleware import ( "github.com/HFO4/cloudreve/pkg/conf" + "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" "testing" ) @@ -28,3 +31,41 @@ func TestSession(t *testing.T) { func emptyFunc() gin.HandlerFunc { return func(c *gin.Context) {} } + +func TestCSRFInit(t *testing.T) { + asserts := assert.New(t) + rec := httptest.NewRecorder() + sessionFunc := Session("233") + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("GET", "/test", nil) + sessionFunc(c) + CSRFInit()(c) + asserts.True(util.GetSession(c, "CSRF").(bool)) + } +} + +func TestCSRFCheck(t *testing.T) { + asserts := assert.New(t) + rec := httptest.NewRecorder() + sessionFunc := Session("233") + + // 通过检查 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("GET", "/test", nil) + sessionFunc(c) + CSRFInit()(c) + CSRFCheck()(c) + asserts.False(c.IsAborted()) + } + + // 未通过检查 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("GET", "/test", nil) + sessionFunc(c) + CSRFCheck()(c) + asserts.True(c.IsAborted()) + } +} diff --git a/routers/router.go b/routers/router.go index bb43135..29a5576 100644 --- a/routers/router.go +++ b/routers/router.go @@ -107,7 +107,7 @@ func InitMasterRouter() *gin.Engine { // 验证码 site.GET("captcha", controllers.Captcha) // 站点全局配置 - site.GET("config", controllers.SiteConfig) + site.GET("config", middleware.CSRFInit(), controllers.SiteConfig) } // 用户相关路由 @@ -231,6 +231,7 @@ func InitMasterRouter() *gin.Engine { ) // 预览分享文件 share.GET("preview/:id", + middleware.CSRFCheck(), middleware.CheckShareUnlocked(), middleware.ShareCanPreview(), middleware.BeforeShareDownload(),