Fix: file preview URL in share page should not be accessed directly

pull/265/head
HFO4 5 years ago
parent 79f898e0a9
commit 32c0232105

@ -2,6 +2,7 @@ package middleware
import ( import (
"github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/memstore"
@ -32,3 +33,24 @@ func Session(secret string) gin.HandlerFunc {
Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"})
return sessions.Sessions("cloudreve-session", Store) return sessions.Sessions("cloudreve-session", Store)
} }
// CSRFInit 初始化CSRF标记
func CSRFInit() gin.HandlerFunc {
return func(c *gin.Context) {
util.SetSession(c, map[string]interface{}{"CSRF": true})
c.Next()
}
}
// CSRFCheck 检查CSRF标记
func CSRFCheck() gin.HandlerFunc {
return func(c *gin.Context) {
if check, ok := util.GetSession(c, "CSRF").(bool); ok && check {
c.Next()
return
}
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil))
c.Abort()
}
}

@ -2,8 +2,11 @@ package middleware
import ( import (
"github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing" "testing"
) )
@ -28,3 +31,41 @@ func TestSession(t *testing.T) {
func emptyFunc() gin.HandlerFunc { func emptyFunc() gin.HandlerFunc {
return func(c *gin.Context) {} return func(c *gin.Context) {}
} }
func TestCSRFInit(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
asserts.True(util.GetSession(c, "CSRF").(bool))
}
}
func TestCSRFCheck(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
// 通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
CSRFCheck()(c)
asserts.False(c.IsAborted())
}
// 未通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFCheck()(c)
asserts.True(c.IsAborted())
}
}

@ -107,7 +107,7 @@ func InitMasterRouter() *gin.Engine {
// 验证码 // 验证码
site.GET("captcha", controllers.Captcha) site.GET("captcha", controllers.Captcha)
// 站点全局配置 // 站点全局配置
site.GET("config", controllers.SiteConfig) site.GET("config", middleware.CSRFInit(), controllers.SiteConfig)
} }
// 用户相关路由 // 用户相关路由
@ -231,6 +231,7 @@ func InitMasterRouter() *gin.Engine {
) )
// 预览分享文件 // 预览分享文件
share.GET("preview/:id", share.GET("preview/:id",
middleware.CSRFCheck(),
middleware.CheckShareUnlocked(), middleware.CheckShareUnlocked(),
middleware.ShareCanPreview(), middleware.ShareCanPreview(),
middleware.BeforeShareDownload(), middleware.BeforeShareDownload(),

Loading…
Cancel
Save