You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
cloudreve/inventory/oauth_client.go

281 lines
8.6 KiB

package inventory
import (
"context"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/oauthclient"
"github.com/cloudreve/Cloudreve/v4/ent/oauthgrant"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gofrs/uuid"
)
// OAuthClientClient provides persistence for registered OAuth clients and for
// the grants users have issued to them.
type OAuthClientClient interface {
	TxOperator

	// GetByGUID returns the enabled OAuth client identified by guid (its
	// public client_id).
	GetByGUID(ctx context.Context, guid string) (*ent.OAuthClient, error)
	// GetByGUIDWithGrants is like GetByGUID, but also loads the grants that
	// user uid holds for the client.
	GetByGUIDWithGrants(ctx context.Context, guid string, uid int) (*ent.OAuthClient, error)
	// GetByID returns the OAuth client with the given primary key.
	GetByID(ctx context.Context, id int) (*ent.OAuthClient, error)
	// List returns one page of OAuth clients matching args.
	List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error)
	// Create stores a new OAuth client and returns the persisted entity.
	Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error)
	// Update saves changes to an existing OAuth client. An empty Secret
	// leaves the stored secret unchanged.
	Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error)
	// Delete removes the OAuth client with the given ID.
	Delete(ctx context.Context, id int) error

	// UpsertGrant records that userID granted scopes to clientID. It creates
	// the grant, or updates the existing one for the same user/client pair.
	UpsertGrant(ctx context.Context, userID, clientID int, scopes []string) error
	// UpdateGrantLastUsedAt stamps the user's grant for clientID with the
	// current time.
	UpdateGrantLastUsedAt(ctx context.Context, userID, clientID int) error
	// DeleteGrantByUserAndClientGUID revokes the user's grant for the client
	// identified by clientGUID. The boolean reports whether a grant was
	// actually deleted (false when the client or grant does not exist).
	DeleteGrantByUserAndClientGUID(ctx context.Context, userID int, clientGUID string) (bool, error)
	// CountGrants returns how many grants exist for the OAuth client with the
	// given ID.
	CountGrants(ctx context.Context, id int) (int, error)
	// GetGrantsByUserID returns every grant held by the given user.
	GetGrantsByUserID(ctx context.Context, userID int) ([]*ent.OAuthGrant, error)
}

// ListOAuthClientArgs holds the filter and pagination options accepted by
// OAuthClientClient.List.
type ListOAuthClientArgs struct {
	*PaginationArgs
	// Name, when non-empty, keeps only clients whose name contains it.
	Name string
	// IsEnabled, when non-nil, keeps only clients with that enabled state.
	IsEnabled *bool
}

// ListOAuthClientResult is one page of OAuth clients plus pagination metadata.
type ListOAuthClientResult struct {
	*PaginationResults
	Clients []*ent.OAuthClient
}

// LoadOAuthGrantClient is a context key. When a context carries the bool
// value true under this key, grant queries eager-load each grant's OAuth
// client.
type LoadOAuthGrantClient struct{}
// NewOAuthClientClient returns an ent-backed OAuthClientClient that runs its
// queries on client. dbType selects the bind-parameter limit, which List uses
// to cap page sizes.
func NewOAuthClientClient(client *ent.Client, dbType conf.DBType) OAuthClientClient {
	paramLimit := sqlParamLimit(dbType)
	return &oauthClientClient{client: client, maxSQlParam: paramLimit}
}

// oauthClientClient is the ent-backed implementation of OAuthClientClient.
type oauthClientClient struct {
	// client executes every query issued by this value.
	client *ent.Client
	// maxSQlParam is the bind-parameter limit derived from the database type.
	maxSQlParam int
}

// SetClient returns a copy of c that runs queries on newClient, keeping the
// same parameter limit. TxOperator presumably uses this to swap in a
// transaction-bound client.
func (c *oauthClientClient) SetClient(newClient *ent.Client) TxOperator {
	clone := *c
	clone.client = newClient
	return &clone
}

// GetClient returns the ent client this value runs queries on.
func (c *oauthClientClient) GetClient() *ent.Client {
	return c.client
}
// GetByGUID looks up an OAuth client by its GUID (the public client_id).
// Disabled clients are filtered out, so they are reported the same way as a
// missing client (ent's not-found error from First).
func (c *oauthClientClient) GetByGUID(ctx context.Context, guid string) (*ent.OAuthClient, error) {
	q := c.client.OAuthClient.Query()
	q = q.Where(oauthclient.GUID(guid))
	q = q.Where(oauthclient.IsEnabled(true))
	return q.First(ctx)
}
// GetByGUIDWithGrants looks up an enabled OAuth client by GUID, the same way
// GetByGUID does. When uid is positive, it also eager-loads the grants that
// user uid holds for this client; grants of other users are never attached.
func (c *oauthClientClient) GetByGUIDWithGrants(ctx context.Context, guid string, uid int) (*ent.OAuthClient, error) {
	query := c.client.OAuthClient.Query().
		Where(oauthclient.GUID(guid), oauthclient.IsEnabled(true))
	if uid <= 0 {
		// No user to scope grants to: return the bare client.
		return query.First(ctx)
	}

	return query.
		WithGrants(func(grants *ent.OAuthGrantQuery) {
			grants.Where(oauthgrant.UserID(uid))
		}).
		First(ctx)
}
// UpsertGrant inserts a grant of scopes from userID to clientID, stamped with
// the current time. If a grant for the same (user_id, client_id) pair already
// conflicts, that row's scopes and last-used timestamp are overwritten with
// the new values instead.
// NOTE(review): the conflict target presumably relies on a unique index over
// (user_id, client_id); confirm it exists in the schema.
func (c *oauthClientClient) UpsertGrant(ctx context.Context, userID, clientID int, scopes []string) error {
	now := time.Now()
	insert := c.client.OAuthGrant.Create().
		SetUserID(userID).
		SetClientID(clientID).
		SetScopes(scopes).
		SetLastUsedAt(now)

	conflictOn := sql.ConflictColumns(oauthgrant.FieldUserID, oauthgrant.FieldClientID)
	return insert.OnConflict(conflictOn).
		UpdateScopes().
		UpdateLastUsedAt().
		Exec(ctx)
}
// UpdateGrantLastUsedAt sets the last-used timestamp to now on the grant(s)
// matching the given user and client.
func (c *oauthClientClient) UpdateGrantLastUsedAt(ctx context.Context, userID, clientID int) error {
	upd := c.client.OAuthGrant.Update()
	upd.Where(oauthgrant.UserID(userID), oauthgrant.ClientID(clientID))
	upd.SetLastUsedAt(time.Now())
	return upd.Exec(ctx)
}
// GetGrantsByUserID returns all grants held by userID. When ctx carries
// LoadOAuthGrantClient{} = true, each grant's OAuth client is loaded as well.
func (c *oauthClientClient) GetGrantsByUserID(ctx context.Context, userID int) ([]*ent.OAuthGrant, error) {
	byUser := c.client.OAuthGrant.Query().Where(oauthgrant.UserID(userID))
	return withOAuthGrantEagerLoadings(ctx, byUser).All(ctx)
}
// DeleteGrantByUserAndClientGUID revokes userID's grant for the client whose
// GUID is clientGUID. It returns true only if at least one grant row was
// removed. An unknown GUID yields (false, nil) rather than an error.
//
// The GUID lookup has no IsEnabled filter, so grants to disabled clients can
// be deleted too.
func (c *oauthClientClient) DeleteGrantByUserAndClientGUID(ctx context.Context, userID int, clientGUID string) (bool, error) {
	target, err := c.client.OAuthClient.Query().
		Where(oauthclient.GUID(clientGUID)).
		First(ctx)
	switch {
	case ent.IsNotFound(err):
		// No such client, so there is nothing to revoke.
		return false, nil
	case err != nil:
		return false, fmt.Errorf("failed to get OAuth client: %w", err)
	}

	removed, err := c.client.OAuthGrant.Delete().
		Where(oauthgrant.ClientID(target.ID), oauthgrant.UserID(userID)).
		Exec(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to delete OAuth grant: %w", err)
	}
	return removed != 0, nil
}
// List returns one page of OAuth clients matching args, along with the total
// number of matching clients across all pages.
//
// Filters: a non-empty Name keeps clients whose name contains it, and a
// non-nil IsEnabled keeps clients with that enabled state. The page size is
// capped by capPageSize using the database's bind-parameter limit. Ordering
// comes from getOAuthClientOrderOption.
//
// Defaults: a nil args, or args with a nil embedded PaginationArgs, is treated
// as zero-valued pagination instead of panicking. A negative Page is clamped
// to 0 so the query never gets a negative OFFSET. Normalization happens on a
// local copy; the caller's args are never modified.
func (c *oauthClientClient) List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error) {
	var normalized ListOAuthClientArgs
	if args != nil {
		normalized = *args
	}
	if normalized.PaginationArgs == nil {
		normalized.PaginationArgs = &PaginationArgs{}
	}
	args = &normalized

	page := args.Page
	if page < 0 {
		page = 0
	}

	query := c.client.OAuthClient.Query()
	if args.Name != "" {
		query.Where(oauthclient.NameContains(args.Name))
	}
	if args.IsEnabled != nil {
		query.Where(oauthclient.IsEnabled(*args.IsEnabled))
	}

	pageSize := capPageSize(c.maxSQlParam, args.PageSize, 1)

	// Count on a clone so the ordering/limit applied below does not affect the total.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to count OAuth clients: %w", err)
	}

	query.Order(getOAuthClientOrderOption(args)...)
	clients, err := query.
		Limit(pageSize).
		Offset(page * pageSize).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list OAuth clients: %w", err)
	}

	return &ListOAuthClientResult{
		PaginationResults: &PaginationResults{
			TotalItems: total,
			Page:       page,
			PageSize:   pageSize,
		},
		Clients: clients,
	}, nil
}
// GetByID returns the OAuth client whose primary key is id. Unlike GetByGUID,
// it does not filter on IsEnabled.
func (c *oauthClientClient) GetByID(ctx context.Context, id int) (*ent.OAuthClient, error) {
	byID := oauthclient.ID(id)
	return c.client.OAuthClient.Query().Where(byID).First(ctx)
}
// Create inserts a new OAuth client built from the fields of client and
// returns the stored entity.
//
// Missing values are filled in on client itself before insertion, so the
// caller's struct reflects what was saved:
//   - Props defaults to an empty types.OAuthClientProps.
//   - GUID (the public client_id) defaults to a freshly generated UUIDv4.
//   - Secret defaults to a 32-character value from util.RandStringRunes.
//
// If UUID generation fails, the error is returned instead of panicking, and
// nothing is written to the database.
func (c *oauthClientClient) Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) {
	if client.Props == nil {
		client.Props = &types.OAuthClientProps{}
	}

	if client.GUID == "" {
		guid, err := uuid.NewV4()
		if err != nil {
			return nil, fmt.Errorf("failed to generate OAuth client GUID: %w", err)
		}
		client.GUID = guid.String()
	}

	if client.Secret == "" {
		// NOTE(review): the client secret is a credential; confirm that
		// util.RandStringRunes draws from a cryptographically secure source.
		client.Secret = util.RandStringRunes(32)
	}

	return c.client.OAuthClient.Create().
		SetGUID(client.GUID).
		SetSecret(client.Secret).
		SetName(client.Name).
		SetHomepageURL(client.HomepageURL).
		SetRedirectUris(client.RedirectUris).
		SetScopes(client.Scopes).
		SetProps(client.Props).
		SetIsEnabled(client.IsEnabled).
		Save(ctx)
}
// Update writes the name, homepage URL, redirect URIs, scopes, props and
// enabled flag of client to the row with ID client.ID, and returns the
// updated entity.
//
// A nil Props is replaced with an empty types.OAuthClientProps, and that
// replacement is also stored on client. An empty Secret leaves the stored
// secret untouched. The GUID is never changed here.
func (c *oauthClientClient) Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) {
	props := client.Props
	if props == nil {
		props = &types.OAuthClientProps{}
		client.Props = props
	}

	upd := c.client.OAuthClient.UpdateOneID(client.ID)
	upd.SetName(client.Name)
	upd.SetHomepageURL(client.HomepageURL)
	upd.SetRedirectUris(client.RedirectUris)
	upd.SetScopes(client.Scopes)
	upd.SetProps(props)
	upd.SetIsEnabled(client.IsEnabled)

	if secret := client.Secret; secret != "" {
		upd.SetSecret(secret)
	}
	return upd.Save(ctx)
}
// Delete removes every grant that references OAuth client id, then the
// client itself. The client row is deleted only if removing the grants
// succeeded.
//
// NOTE(review): these are two separate statements. Unless this runs on a
// transaction-bound client (see SetClient), a failure of the second delete
// leaves the grants removed while the client remains.
func (c *oauthClientClient) Delete(ctx context.Context, id int) error {
	if _, err := c.client.OAuthGrant.Delete().Where(oauthgrant.ClientID(id)).Exec(ctx); err != nil {
		return fmt.Errorf("failed to delete OAuth grants: %w", err)
	}
	return c.client.OAuthClient.DeleteOneID(id).Exec(ctx)
}
// CountGrants reports how many grant rows reference the OAuth client with
// the given ID.
func (c *oauthClientClient) CountGrants(ctx context.Context, id int) (int, error) {
	grants := c.client.OAuthGrant.Query()
	return grants.Where(oauthgrant.ClientID(id)).Count(ctx)
}
// getOAuthClientOrderOption turns args.OrderBy/args.Order into ent ordering
// options for OAuth client listings. Supported sort fields are name,
// created_at and is_enabled. Each is followed by the ID, in the same
// direction, as a stable tiebreaker. Any other OrderBy value sorts by ID
// alone.
func getOAuthClientOrderOption(args *ListOAuthClientArgs) []oauthclient.OrderOption {
	term := getOrderTerm(args.Order)
	tiebreak := oauthclient.ByID(term)

	var primary oauthclient.OrderOption
	switch args.OrderBy {
	case oauthclient.FieldName:
		primary = oauthclient.ByName(term)
	case oauthclient.FieldCreatedAt:
		primary = oauthclient.ByCreatedAt(term)
	case oauthclient.FieldIsEnabled:
		primary = oauthclient.ByIsEnabled(term)
	default:
		return []oauthclient.OrderOption{tiebreak}
	}
	return []oauthclient.OrderOption{primary, tiebreak}
}
// withOAuthGrantEagerLoadings adds eager loading of each grant's OAuth client
// to q when ctx carries the bool value true under LoadOAuthGrantClient{}. It
// returns q, which is modified in place.
func withOAuthGrantEagerLoadings(ctx context.Context, q *ent.OAuthGrantQuery) *ent.OAuthGrantQuery {
	// The comma-ok result is deliberately discarded: an absent or non-bool
	// value leaves loadClient false, meaning "don't load".
	loadClient, _ := ctx.Value(LoadOAuthGrantClient{}).(bool)
	if !loadClient {
		return q
	}
	// No options: load the related client with its default query.
	return q.WithClient()
}