From e871f6e421360fdf5efc007380d19f375fe9f7a6 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Tue, 10 Dec 2019 11:13:33 +0800 Subject: [PATCH] Feat: HMAC auth and check --- main.go | 7 ++-- models/migration.go | 1 + pkg/auth/auth.go | 29 ++++++++++++++++ pkg/auth/hmac.go | 53 +++++++++++++++++++++++++++++ pkg/auth/hmac_test.go | 74 +++++++++++++++++++++++++++++++++++++++++ pkg/serializer/error.go | 2 ++ pkg/util/common.go | 5 +++ 7 files changed, 166 insertions(+), 5 deletions(-) create mode 100644 pkg/auth/auth.go create mode 100644 pkg/auth/hmac.go create mode 100644 pkg/auth/hmac_test.go diff --git a/main.go b/main.go index 1296782..9245afc 100644 --- a/main.go +++ b/main.go @@ -2,25 +2,22 @@ package main import ( "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/authn" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/routers" "github.com/gin-gonic/gin" - "math/rand" - "time" ) func init() { conf.Init("conf/conf.ini") model.Init() - rand.Seed(time.Now().UnixNano()) - // Debug 关闭时,切换为生产模式 if !conf.SystemConfig.Debug { gin.SetMode(gin.ReleaseMode) } - + auth.Init() authn.Init() } diff --git a/models/migration.go b/models/migration.go index 3c78a62..253e90e 100644 --- a/models/migration.go +++ b/models/migration.go @@ -146,6 +146,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "aria2_rpcurl", Value: `http://127.0.0.1:6800/`, Type: "aria2"}, {Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"}, {Name: "task_queue_token", Value: ``, Type: "task"}, + {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, } for _, value := range defaultSettings { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..5c2ed62 --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,29 @@ +package auth + +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/serializer" +) + +var ( + ErrAuthFailed = serializer.NewError(serializer.CodeNoRightErr, "鉴权失败", nil) + ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil) +) + +// General 通用的认证接口 +var General Auth + +// Auth 鉴权认证 +type Auth interface { + // 对给定Body进行签名,expires为0表示永不过期 + Sign(body string, expires int64) string + // 对给定Body和Sign进行检查 + Check(body string, sign string) error +} + +// Init 初始化通用鉴权器 +func Init() { + General = HMACAuth{ + SecretKey: []byte(model.GetSettingByName("secret_key")), + } +} diff --git a/pkg/auth/hmac.go b/pkg/auth/hmac.go new file mode 100644 index 0000000..a482cff --- /dev/null +++ b/pkg/auth/hmac.go @@ -0,0 +1,53 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "fmt" + "io" + "strconv" + "strings" + "time" +) + +// HMACAuth HMAC算法鉴权 +type HMACAuth struct { + SecretKey []byte +} + +// Sign 对给定Body生成expires后失效的签名 +func (auth HMACAuth) Sign(body string, expires int64) string { + h := hmac.New(sha256.New, auth.SecretKey) + expireTimeStamp := strconv.FormatInt(expires, 10) + _, err := io.WriteString(h, body+":"+expireTimeStamp) + if err != nil { + return "" + } + + return fmt.Sprintf("%x", h.Sum(nil)) + ":" + expireTimeStamp +} + +// Check 对给定Body和Sign进行鉴权,包括对expires的检查 +func (auth HMACAuth) Check(body string, sign string) error { + signSlice := strings.Split(sign, ":") + // 如果未携带expires字段 + if signSlice[len(signSlice)-1] == "" { + return ErrAuthFailed + } + + // 验证是否过期 + expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64) + if err != nil { + return ErrAuthFailed.WithError(err) + } + // 如果签名过期 + if expires < time.Now().Unix() && expires != 0 { + return ErrExpired + } + + // 验证签名 + if auth.Sign(body, expires) != sign { + return ErrAuthFailed + } + return nil +} diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go new file mode 100644 index 0000000..376a9c1 --- /dev/null +++ b/pkg/auth/hmac_test.go @@ -0,0 +1,74 @@ +package auth + +import ( + "database/sql" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "testing" +) + +var mock sqlmock.Sqlmock + +func TestMain(m *testing.M) { + // 设置gin为测试模式 + gin.SetMode(gin.TestMode) + + // 初始化sqlmock + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + + mockDB, _ := gorm.Open("mysql", db) + model.DB = mockDB + defer db.Close() + + m.Run() +} + +func TestHMACAuth_Sign(t *testing.T) { + asserts := assert.New(t) + auth := HMACAuth{ + SecretKey: []byte(util.RandStringRunes(256)), + } + + asserts.NotEmpty(auth.Sign("content", 0)) +} + +func TestHMACAuth_Check(t *testing.T) { + asserts := assert.New(t) + auth := HMACAuth{ + SecretKey: []byte(util.RandStringRunes(256)), + } + + // 正常,永不过期 + { + sign := auth.Sign("content", 0) + asserts.NoError(auth.Check("content", sign)) + } + + // 过期 + { + sign := auth.Sign("content", 1) + asserts.Error(auth.Check("content", sign)) + } + + // 签名格式错误 + { + sign := auth.Sign("content", 1) + asserts.Error(auth.Check("content", sign+":")) + } +} + +func TestInit(t *testing.T) { + asserts := assert.New(t) + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312")) + Init() + asserts.NoError(mock.ExpectationsWereMet()) +} diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index f91b2f7..a4fe19f 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -54,6 +54,8 @@ const ( CodeCreateFolderFailed = 40003 // CodeObjectExist 对象已存在 CodeObjectExist = 40004 + // CodeSignExpired 签名过期 + CodeSignExpired = 40005 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 diff --git a/pkg/util/common.go b/pkg/util/common.go index 32b5fd0..7305bb9 100644 --- a/pkg/util/common.go +++ b/pkg/util/common.go @@ -4,8 +4,13 @@ import ( "math/rand" "regexp" "strings" + "time" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + // RandStringRunes 返回随机字符串 func RandStringRunes(n int) string { var letterRunes = []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")