package auth import ( "bytes" "context" "crypto/sha256" "errors" "fmt" "strings" "time" "github.com/cloudreve/Cloudreve/v4/ent" "github.com/cloudreve/Cloudreve/v4/inventory" "github.com/cloudreve/Cloudreve/v4/pkg/cache" "github.com/cloudreve/Cloudreve/v4/pkg/hashid" "github.com/cloudreve/Cloudreve/v4/pkg/logging" "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/cloudreve/Cloudreve/v4/pkg/setting" "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v5" ) type TokenAuth interface { // Issue issues a new pair of credentials for the given user. Issue(ctx context.Context, args *IssueTokenArgs) (*Token, error) // VerifyAndRetrieveUser verifies the given token and inject the user into current context. // Returns if upper caller should continue process other session provider. VerifyAndRetrieveUser(c *gin.Context) (bool, error) // Refresh refreshes the given refresh token and returns a new pair of credentials. Refresh(ctx context.Context, refreshToken string) (*Token, error) // Claims parses the given token string and returns the claims. Claims(ctx context.Context, tokenStr string) (*Claims, error) } type IssueTokenArgs struct { User *ent.User RootTokenID *uuid.UUID ClientID string Scopes []string RefreshTTLOverride time.Duration } // Token stores token pair for authentication type Token struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` AccessExpires time.Time `json:"access_expires"` RefreshExpires time.Time `json:"refresh_expires"` UID int `json:"-"` } type ( TokenType string TokenIDContextKey struct{} ScopeContextKey struct{} ) var ( TokenTypeAccess = TokenType("access") TokenTypeRefresh = TokenType("refresh") ErrInvalidRefreshToken = errors.New("invalid refresh token") ErrUserNotFound = errors.New("user not found") ) const ( AuthorizationHeader = "Authorization" TokenHeaderPrefix = "Bearer " RevokeTokenPrefix = "jwt_revoke_" ) type Claims struct { TokenType TokenType `json:"token_type"` jwt.RegisteredClaims StateHash []byte `json:"state_hash,omitempty"` RootTokenID *uuid.UUID `json:"root_token_id,omitempty"` Scopes []string `json:"scopes,omitempty"` ClientID string `json:"client_id,omitempty"` } // NewTokenAuth creates a new token based auth provider. func NewTokenAuth(idEncoder hashid.Encoder, s setting.Provider, secret []byte, userClient inventory.UserClient, l logging.Logger, kv cache.Driver, oAuthClient inventory.OAuthClientClient) TokenAuth { return &tokenAuth{ idEncoder: idEncoder, s: s, secret: secret, userClient: userClient, l: l, kv: kv, oAuthClient: oAuthClient, } } type tokenAuth struct { l logging.Logger idEncoder hashid.Encoder s setting.Provider secret []byte userClient inventory.UserClient oAuthClient inventory.OAuthClientClient kv cache.Driver } func (t *tokenAuth) Claims(ctx context.Context, tokenStr string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { return t.secret, nil }) if err != nil { return nil, fmt.Errorf("invalid token: %w", err) } claims, ok := token.Claims.(*Claims) if !ok { return nil, fmt.Errorf("invalid token claims") } return claims, nil } func (t *tokenAuth) Refresh(ctx context.Context, refreshToken string) (*Token, error) { token, err := jwt.ParseWithClaims(refreshToken, &Claims{}, func(token *jwt.Token) (interface{}, error) { return t.secret, nil }) if err != nil { return nil, fmt.Errorf("invalid refresh token: %w", err) } claims, ok := token.Claims.(*Claims) if !ok || claims.TokenType != TokenTypeRefresh { return nil, ErrInvalidRefreshToken } uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID) if err != nil { return nil, ErrUserNotFound } expectedUser, err := t.userClient.GetActiveByID(ctx, uid) if err != nil { return nil, ErrUserNotFound } // Check if user changed password or revoked session expectedHash := t.hashUserState(ctx, expectedUser) if !bytes.Equal(claims.StateHash, expectedHash[:]) { return nil, ErrInvalidRefreshToken } // Check if root token is revoked if claims.RootTokenID == nil { return nil, ErrInvalidRefreshToken } _, ok = t.kv.Get(fmt.Sprintf("%s%s", RevokeTokenPrefix, claims.RootTokenID.String())) if ok { return nil, ErrInvalidRefreshToken } // If token issued for an OAuth client, check if the client is still valid refreshTTLOverride := time.Duration(0) if claims.ClientID != "" { client, err := t.oAuthClient.GetByGUIDWithGrants(ctx, claims.ClientID, expectedUser.ID) if err != nil || len(client.Edges.Grants) == 0 { return nil, ErrInvalidRefreshToken } // Consented scopes must be a subset of the client's scopes if !ValidateScopes(claims.Scopes, client.Edges.Grants[0].Scopes) { return nil, ErrInvalidRefreshToken } // Update last used at for the grant if err := t.oAuthClient.UpdateGrantLastUsedAt(ctx, expectedUser.ID, client.ID); err != nil { return nil, ErrInvalidRefreshToken } if client.Props != nil { refreshTTLOverride = time.Duration(client.Props.RefreshTokenTTL) * time.Second } } return t.Issue(ctx, &IssueTokenArgs{ User: expectedUser, RootTokenID: claims.RootTokenID, Scopes: claims.Scopes, ClientID: claims.ClientID, RefreshTTLOverride: refreshTTLOverride, }) } func (t *tokenAuth) VerifyAndRetrieveUser(c *gin.Context) (bool, error) { headerVal := c.GetHeader(AuthorizationHeader) if strings.HasPrefix(headerVal, TokenHeaderPrefixCr) { // This is an HMAC auth header, skip JWT verification return false, nil } tokenString := strings.TrimPrefix(headerVal, TokenHeaderPrefix) if tokenString == "" { return true, nil } token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { return t.secret, nil }) if err != nil { t.l.Warning("Failed to parse jwt token: %s", err) return false, nil } claims, ok := token.Claims.(*Claims) if !ok || claims.TokenType != TokenTypeAccess { return false, serializer.NewError(serializer.CodeCredentialInvalid, "Invalid token type", nil) } uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID) if err != nil { return false, serializer.NewError(serializer.CodeNotFound, "User not found", err) } util.WithValue(c, inventory.UserIDCtx{}, uid) if claims.ClientID != "" { util.WithValue(c, ScopeContextKey{}, claims.Scopes) } return false, nil } func (t *tokenAuth) Issue(ctx context.Context, args *IssueTokenArgs) (*Token, error) { u := args.User rootTokenID := args.RootTokenID uidEncoded := hashid.EncodeUserID(t.idEncoder, u.ID) tokenSettings := t.s.TokenAuth(ctx) issueDate := time.Now() accessTokenExpired := time.Now().Add(tokenSettings.AccessTokenTTL) refreshTokenExpired := time.Now().Add(tokenSettings.RefreshTokenTTL) if args.RefreshTTLOverride > 0 { refreshTokenExpired = time.Now().Add(args.RefreshTTLOverride) } if rootTokenID == nil { newRootTokenID := uuid.Must(uuid.NewV4()) rootTokenID = &newRootTokenID } accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ TokenType: TokenTypeAccess, RegisteredClaims: jwt.RegisteredClaims{ Subject: uidEncoded, NotBefore: jwt.NewNumericDate(issueDate), ExpiresAt: jwt.NewNumericDate(accessTokenExpired), }, Scopes: args.Scopes, }).SignedString(t.secret) if err != nil { return nil, fmt.Errorf("faield to sign access token: %w", err) } userHash := t.hashUserState(ctx, u) refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ TokenType: TokenTypeRefresh, RootTokenID: rootTokenID, RegisteredClaims: jwt.RegisteredClaims{ Subject: uidEncoded, NotBefore: jwt.NewNumericDate(issueDate), ExpiresAt: jwt.NewNumericDate(refreshTokenExpired), }, Scopes: args.Scopes, ClientID: args.ClientID, StateHash: userHash[:], }).SignedString(t.secret) if err != nil { return nil, fmt.Errorf("faield to sign refresh token: %w", err) } return &Token{ AccessToken: accessToken, RefreshToken: refreshToken, AccessExpires: accessTokenExpired, RefreshExpires: refreshTokenExpired, UID: u.ID, }, nil } // hashUserState returns a hash string for user state for critical fields, it is used // to detect refresh token revocation after user changed password. func (t *tokenAuth) hashUserState(ctx context.Context, u *ent.User) [32]byte { return sha256.Sum256([]byte(fmt.Sprintf("%s/%s/%s", u.Email, u.Password, t.s.SiteBasic(ctx).ID))) } // ValidateScopes checks if all requested scopes are a subset of the allowed scopes. // Returns true if all requested scopes are valid, false otherwise. func ValidateScopes(requestedScopes, allowedScopes []string) bool { allowed := make(map[string]struct{}, len(allowedScopes)) for _, scope := range allowedScopes { allowed[scope] = struct{}{} } for _, scope := range requestedScopes { if _, ok := allowed[scope]; !ok { return false } } return true } func GetScopesFromContext(ctx context.Context) (bool, []string) { scopes, ok := ctx.Value(ScopeContextKey{}).([]string) if !ok { return false, nil } return true, scopes } func CheckScope(c *gin.Context, requiredScopes ...string) error { hasScopes, tokenScopes := GetScopesFromContext(c) if !hasScopes { return nil } // Build a set of token scopes including implicit read permissions from write scopes scopeSet := make(map[string]struct{}, len(tokenScopes)*2) for _, scope := range tokenScopes { scopeSet[scope] = struct{}{} // If scope is "xxx.Write", also grant "xxx.Read" if resource, ok := extractWriteResource(scope); ok { scopeSet[resource+".Read"] = struct{}{} } } // Check if all required scopes are present for _, required := range requiredScopes { if _, ok := scopeSet[required]; !ok { return serializer.NewError(serializer.CodeInsufficientScope, "Insufficient scope: "+required, nil) } } return nil } // extractWriteResource extracts the resource name from a write scope. // For example, "File.Write" returns ("File", true), "File.Read" returns ("", false). func extractWriteResource(scope string) (string, bool) { const writeSuffix = ".Write" if len(scope) > len(writeSuffix) && scope[len(scope)-len(writeSuffix):] == writeSuffix { return scope[:len(scope)-len(writeSuffix)], true } return "", false }