fix update jwt-token version to avoid attackers to bypass intended access restrictions in situations with []string{} for m["aud"]

pull/27/head
senyu 4 years ago
parent b9bda9c513
commit df0ec6c804

@ -134,8 +134,8 @@ multiloginpolicy:
tokenpolicy: tokenpolicy:
accessSecret: "open_im_server" accessSecret: "open_im_server"
# Token effective time seconds as a unit # Token effective time seconds as a unit
#Seven days 7*24*60*60 #Seven days
accessExpire: 604800 accessExpire: 7
messagecallback: messagecallback:
callbackSwitch: false callbackSwitch: false

@ -19,24 +19,18 @@ var (
type Claims struct { type Claims struct {
UID string UID string
Platform string //login platform Platform string //login platform
jwt.StandardClaims jwt.RegisteredClaims
} }
func BuildClaims(uid, platform string, ttl int64) Claims { func BuildClaims(uid, platform string, ttl int64) Claims {
now := time.Now().Unix() now := time.Now()
//if ttl=-1 Permanent token
expiresAt := int64(-1)
if ttl != -1 {
expiresAt = now + ttl
}
return Claims{ return Claims{
UID: uid, UID: uid,
Platform: platform, Platform: platform,
StandardClaims: jwt.StandardClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: expiresAt, //Expiration time ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
IssuedAt: now, //Issuing time IssuedAt: jwt.NewNumericDate(now), //Issuing time
NotBefore: now, //Begin Effective time NotBefore: jwt.NewNumericDate(now), //Begin Effective time
}} }}
} }
@ -45,7 +39,7 @@ func CreateToken(userID string, platform int32) (string, int64, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret)) tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret))
return tokenString, claims.ExpiresAt, err return tokenString, claims.ExpiresAt.Time.Unix(), err
} }
func secret() jwt.Keyfunc { func secret() jwt.Keyfunc {
@ -105,7 +99,7 @@ func ParseToken(tokensString string) (claims *Claims, err error) {
exists = existsInterface.(int64) exists = existsInterface.(int64)
if exists == 1 { if exists == 1 {
res, err := MakeTheTokenInvalid(*claims, platform) res, err := MakeTheTokenInvalid(claims, platform)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,7 +112,7 @@ func ParseToken(tokensString string) (claims *Claims, err error) {
// or PC/Mobile validate success // or PC/Mobile validate success
// final check // final check
if exists == 1 { if exists == 1 {
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) res, err := MakeTheTokenInvalid(claims, Platform2class[claims.Platform])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,7 +123,7 @@ func ParseToken(tokensString string) (claims *Claims, err error) {
return claims, nil return claims, nil
} }
func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, error) { func MakeTheTokenInvalid(currentClaims *Claims, platformClass string) (bool, error) {
storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass) storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass)
if err != nil { if err != nil {
return false, err return false, err
@ -139,7 +133,7 @@ func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, erro
return false, err return false, err
} }
//if issue time less than redis token then make this token invalid //if issue time less than redis token then make this token invalid
if currentClaims.IssuedAt < storedRedisPlatformClaims.IssuedAt { if currentClaims.IssuedAt.Time.Unix() < storedRedisPlatformClaims.IssuedAt.Time.Unix() {
return true, TokenInvalid return true, TokenInvalid
} }
return false, nil return false, nil

@ -17,18 +17,18 @@ func Test_BuildClaims(t *testing.T) {
assert.Equal(t, claim.UID, uid, "uid should equal") assert.Equal(t, claim.UID, uid, "uid should equal")
assert.Equal(t, claim.Platform, platform, "platform should equal") assert.Equal(t, claim.Platform, platform, "platform should equal")
assert.Equal(t, claim.StandardClaims.ExpiresAt, int64(-1), "StandardClaims.ExpiresAt should be equal") assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(-1), "StandardClaims.ExpiresAt should be equal")
// time difference within 1s // time difference within 1s
assert.Equal(t, claim.StandardClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal") assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
assert.Equal(t, claim.StandardClaims.NotBefore, now, "StandardClaims.NotBefore should be equal") assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
ttl = int64(60) ttl = int64(60)
now = time.Now().Unix() now = time.Now().Unix()
claim = BuildClaims(uid, platform, ttl) claim = BuildClaims(uid, platform, ttl)
// time difference within 1s // time difference within 1s
assert.Equal(t, claim.StandardClaims.ExpiresAt, int64(60)+now, "StandardClaims.ExpiresAt should be equal") assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(60)+now, "StandardClaims.ExpiresAt should be equal")
assert.Equal(t, claim.StandardClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal") assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
assert.Equal(t, claim.StandardClaims.NotBefore, now, "StandardClaims.NotBefore should be equal") assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
} }
func Test_CreateToken(t *testing.T) { func Test_CreateToken(t *testing.T) {

Loading…
Cancel
Save