You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
281 lines
8.6 KiB
281 lines
8.6 KiB
package inventory
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"entgo.io/ent/dialect/sql"
|
|
"github.com/cloudreve/Cloudreve/v4/ent"
|
|
"github.com/cloudreve/Cloudreve/v4/ent/oauthclient"
|
|
"github.com/cloudreve/Cloudreve/v4/ent/oauthgrant"
|
|
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
|
"github.com/gofrs/uuid"
|
|
)
|
|
|
|
type (
|
|
OAuthClientClient interface {
|
|
TxOperator
|
|
// GetByGUID returns the OAuth client by its GUID (client_id).
|
|
GetByGUID(ctx context.Context, guid string) (*ent.OAuthClient, error)
|
|
// GetByGUIDWithGrants returns the OAuth client by its GUID (client_id) with the grants for the user.
|
|
GetByGUIDWithGrants(ctx context.Context, guid string, uid int) (*ent.OAuthClient, error)
|
|
// UpsertGrant creates or updates an OAuth grant for a user and client.
|
|
UpsertGrant(ctx context.Context, userID, clientID int, scopes []string) error
|
|
// UpdateGrantLastUsedAt updates the last used at for an OAuth grant for a user and client.
|
|
UpdateGrantLastUsedAt(ctx context.Context, userID, clientID int) error
|
|
// DeleteGrantByUserAndClientGUID deletes an OAuth grant for a user by the client GUID.
|
|
// Returns true if the grant was deleted, false if it was not found.
|
|
DeleteGrantByUserAndClientGUID(ctx context.Context, userID int, clientGUID string) (bool, error)
|
|
// List returns a paginated list of OAuth clients.
|
|
List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error)
|
|
// GetByID returns the OAuth client by its ID.
|
|
GetByID(ctx context.Context, id int) (*ent.OAuthClient, error)
|
|
// Create creates a new OAuth client.
|
|
Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error)
|
|
// Update updates an existing OAuth client.
|
|
Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error)
|
|
// Delete deletes an OAuth client by its ID.
|
|
Delete(ctx context.Context, id int) error
|
|
// CountGrants returns the number of grants for an OAuth client.
|
|
CountGrants(ctx context.Context, id int) (int, error)
|
|
// GetGrantsByUserID returns the OAuth grants for a user.
|
|
GetGrantsByUserID(ctx context.Context, userID int) ([]*ent.OAuthGrant, error)
|
|
}
|
|
|
|
ListOAuthClientArgs struct {
|
|
*PaginationArgs
|
|
Name string
|
|
IsEnabled *bool
|
|
}
|
|
|
|
ListOAuthClientResult struct {
|
|
*PaginationResults
|
|
Clients []*ent.OAuthClient
|
|
}
|
|
|
|
LoadOAuthGrantClient struct{}
|
|
)
|
|
|
|
func NewOAuthClientClient(client *ent.Client, dbType conf.DBType) OAuthClientClient {
|
|
return &oauthClientClient{
|
|
client: client,
|
|
maxSQlParam: sqlParamLimit(dbType),
|
|
}
|
|
}
|
|
|
|
type oauthClientClient struct {
|
|
client *ent.Client
|
|
maxSQlParam int
|
|
}
|
|
|
|
func (c *oauthClientClient) SetClient(newClient *ent.Client) TxOperator {
|
|
return &oauthClientClient{client: newClient, maxSQlParam: c.maxSQlParam}
|
|
}
|
|
|
|
func (c *oauthClientClient) GetClient() *ent.Client {
|
|
return c.client
|
|
}
|
|
|
|
func (c *oauthClientClient) GetByGUID(ctx context.Context, guid string) (*ent.OAuthClient, error) {
|
|
return c.client.OAuthClient.Query().
|
|
Where(oauthclient.GUID(guid), oauthclient.IsEnabled(true)).
|
|
First(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) GetByGUIDWithGrants(ctx context.Context, guid string, uid int) (*ent.OAuthClient, error) {
|
|
stm := c.client.OAuthClient.Query().
|
|
Where(oauthclient.GUID(guid), oauthclient.IsEnabled(true))
|
|
if uid > 0 {
|
|
stm.WithGrants(func(ogq *ent.OAuthGrantQuery) {
|
|
ogq.Where(oauthgrant.UserID(uid))
|
|
})
|
|
}
|
|
|
|
return stm.First(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) UpsertGrant(ctx context.Context, userID, clientID int, scopes []string) error {
|
|
return c.client.OAuthGrant.Create().
|
|
SetUserID(userID).
|
|
SetClientID(clientID).
|
|
SetScopes(scopes).
|
|
SetLastUsedAt(time.Now()).
|
|
OnConflict(
|
|
sql.ConflictColumns(oauthgrant.FieldUserID, oauthgrant.FieldClientID),
|
|
).
|
|
UpdateScopes().
|
|
UpdateLastUsedAt().
|
|
Exec(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) UpdateGrantLastUsedAt(ctx context.Context, userID, clientID int) error {
|
|
return c.client.OAuthGrant.Update().
|
|
Where(oauthgrant.UserID(userID), oauthgrant.ClientID(clientID)).
|
|
SetLastUsedAt(time.Now()).
|
|
Exec(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) GetGrantsByUserID(ctx context.Context, userID int) ([]*ent.OAuthGrant, error) {
|
|
return withOAuthGrantEagerLoadings(ctx, c.client.OAuthGrant.Query()).
|
|
Where(oauthgrant.UserID(userID)).
|
|
All(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) DeleteGrantByUserAndClientGUID(ctx context.Context, userID int, clientGUID string) (bool, error) {
|
|
// First, get the client by GUID to get its ID
|
|
client, err := c.client.OAuthClient.Query().
|
|
Where(oauthclient.GUID(clientGUID)).
|
|
First(ctx)
|
|
if err != nil {
|
|
if ent.IsNotFound(err) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("failed to get OAuth client: %w", err)
|
|
}
|
|
|
|
// Delete the grant for this user and client
|
|
deleted, err := c.client.OAuthGrant.Delete().
|
|
Where(oauthgrant.UserID(userID), oauthgrant.ClientID(client.ID)).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to delete OAuth grant: %w", err)
|
|
}
|
|
|
|
return deleted > 0, nil
|
|
}
|
|
|
|
func (c *oauthClientClient) List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error) {
|
|
query := c.client.OAuthClient.Query()
|
|
|
|
if args.Name != "" {
|
|
query.Where(oauthclient.NameContains(args.Name))
|
|
}
|
|
|
|
if args.IsEnabled != nil {
|
|
query.Where(oauthclient.IsEnabled(*args.IsEnabled))
|
|
}
|
|
|
|
pageSize := capPageSize(c.maxSQlParam, args.PageSize, 1)
|
|
|
|
total, err := query.Clone().Count(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to count OAuth clients: %w", err)
|
|
}
|
|
|
|
query.Order(getOAuthClientOrderOption(args)...)
|
|
|
|
clients, err := query.
|
|
Limit(pageSize).
|
|
Offset(args.Page * pageSize).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list OAuth clients: %w", err)
|
|
}
|
|
|
|
return &ListOAuthClientResult{
|
|
PaginationResults: &PaginationResults{
|
|
TotalItems: total,
|
|
Page: args.Page,
|
|
PageSize: pageSize,
|
|
},
|
|
Clients: clients,
|
|
}, nil
|
|
}
|
|
|
|
func (c *oauthClientClient) GetByID(ctx context.Context, id int) (*ent.OAuthClient, error) {
|
|
return c.client.OAuthClient.Query().
|
|
Where(oauthclient.ID(id)).
|
|
First(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) {
|
|
if client.Props == nil {
|
|
client.Props = &types.OAuthClientProps{}
|
|
}
|
|
|
|
// Generate a new GUID and secret if not provided
|
|
if client.GUID == "" {
|
|
client.GUID = uuid.Must(uuid.NewV4()).String()
|
|
}
|
|
if client.Secret == "" {
|
|
client.Secret = util.RandStringRunes(32)
|
|
}
|
|
|
|
return c.client.OAuthClient.Create().
|
|
SetGUID(client.GUID).
|
|
SetSecret(client.Secret).
|
|
SetName(client.Name).
|
|
SetHomepageURL(client.HomepageURL).
|
|
SetRedirectUris(client.RedirectUris).
|
|
SetScopes(client.Scopes).
|
|
SetProps(client.Props).
|
|
SetIsEnabled(client.IsEnabled).
|
|
Save(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) {
|
|
if client.Props == nil {
|
|
client.Props = &types.OAuthClientProps{}
|
|
}
|
|
|
|
update := c.client.OAuthClient.UpdateOneID(client.ID).
|
|
SetName(client.Name).
|
|
SetHomepageURL(client.HomepageURL).
|
|
SetRedirectUris(client.RedirectUris).
|
|
SetScopes(client.Scopes).
|
|
SetProps(client.Props).
|
|
SetIsEnabled(client.IsEnabled)
|
|
|
|
// Only update secret if provided (non-empty)
|
|
if client.Secret != "" {
|
|
update.SetSecret(client.Secret)
|
|
}
|
|
|
|
return update.Save(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) Delete(ctx context.Context, id int) error {
|
|
// Delete all grants first
|
|
_, err := c.client.OAuthGrant.Delete().
|
|
Where(oauthgrant.ClientID(id)).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete OAuth grants: %w", err)
|
|
}
|
|
|
|
// Delete the client
|
|
return c.client.OAuthClient.DeleteOneID(id).Exec(ctx)
|
|
}
|
|
|
|
func (c *oauthClientClient) CountGrants(ctx context.Context, id int) (int, error) {
|
|
return c.client.OAuthGrant.Query().
|
|
Where(oauthgrant.ClientID(id)).
|
|
Count(ctx)
|
|
}
|
|
|
|
func getOAuthClientOrderOption(args *ListOAuthClientArgs) []oauthclient.OrderOption {
|
|
orderTerm := getOrderTerm(args.Order)
|
|
switch args.OrderBy {
|
|
case oauthclient.FieldName:
|
|
return []oauthclient.OrderOption{oauthclient.ByName(orderTerm), oauthclient.ByID(orderTerm)}
|
|
case oauthclient.FieldCreatedAt:
|
|
return []oauthclient.OrderOption{oauthclient.ByCreatedAt(orderTerm), oauthclient.ByID(orderTerm)}
|
|
case oauthclient.FieldIsEnabled:
|
|
return []oauthclient.OrderOption{oauthclient.ByIsEnabled(orderTerm), oauthclient.ByID(orderTerm)}
|
|
default:
|
|
return []oauthclient.OrderOption{oauthclient.ByID(orderTerm)}
|
|
}
|
|
}
|
|
|
|
func withOAuthGrantEagerLoadings(ctx context.Context, q *ent.OAuthGrantQuery) *ent.OAuthGrantQuery {
|
|
if v, ok := ctx.Value(LoadOAuthGrantClient{}).(bool); ok && v {
|
|
q.WithClient(func(ocq *ent.OAuthClientQuery) {
|
|
})
|
|
}
|
|
|
|
return q
|
|
}
|