diff --git a/application/dependency/dependency.go b/application/dependency/dependency.go index 1bbfe8a6..adb6421b 100644 --- a/application/dependency/dependency.go +++ b/application/dependency/dependency.go @@ -491,7 +491,7 @@ func (d *dependency) OAuthClientClient() inventory.OAuthClientClient { return d.oAuthClient } - return inventory.NewOAuthClientClient(d.DBClient()) + return inventory.NewOAuthClientClient(d.DBClient(), d.ConfigProvider().Database().Type) } func (d *dependency) MimeDetector(ctx context.Context) mime.MimeDetector { diff --git a/assets b/assets index ee51bb64..c0f3e502 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit ee51bb6483d69f99c86db4c370ca742e486cfad9 +Subproject commit c0f3e50207b5783220fafc26a0c9885aadd38af1 diff --git a/inventory/oauth_client.go b/inventory/oauth_client.go index 56065bdc..c9f6b3c0 100644 --- a/inventory/oauth_client.go +++ b/inventory/oauth_client.go @@ -2,12 +2,17 @@ package inventory import ( "context" + "fmt" "time" "entgo.io/ent/dialect/sql" "github.com/cloudreve/Cloudreve/v4/ent" "github.com/cloudreve/Cloudreve/v4/ent/oauthclient" "github.com/cloudreve/Cloudreve/v4/ent/oauthgrant" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" ) type ( @@ -21,21 +26,46 @@ type ( UpsertGrant(ctx context.Context, userID, clientID int, scopes []string) error // UpdateGrantLastUsedAt updates the last used at for an OAuth grant for a user and client. UpdateGrantLastUsedAt(ctx context.Context, userID, clientID int) error + // List returns a paginated list of OAuth clients. + List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error) + // GetByID returns the OAuth client by its ID. + GetByID(ctx context.Context, id int) (*ent.OAuthClient, error) + // Create creates a new OAuth client. + Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) + // Update updates an existing OAuth client. + Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) + // Delete deletes an OAuth client by its ID. + Delete(ctx context.Context, id int) error + // CountGrants returns the number of grants for an OAuth client. + CountGrants(ctx context.Context, id int) (int, error) + } + + ListOAuthClientArgs struct { + *PaginationArgs + Name string + IsEnabled *bool + } + + ListOAuthClientResult struct { + *PaginationResults + Clients []*ent.OAuthClient } ) -func NewOAuthClientClient(client *ent.Client) OAuthClientClient { +func NewOAuthClientClient(client *ent.Client, dbType conf.DBType) OAuthClientClient { return &oauthClientClient{ - client: client, + client: client, + maxSQlParam: sqlParamLimit(dbType), } } type oauthClientClient struct { - client *ent.Client + client *ent.Client + maxSQlParam int } func (c *oauthClientClient) SetClient(newClient *ent.Client) TxOperator { - return &oauthClientClient{client: newClient} + return &oauthClientClient{client: newClient, maxSQlParam: c.maxSQlParam} } func (c *oauthClientClient) GetClient() *ent.Client { @@ -80,3 +110,126 @@ func (c *oauthClientClient) UpdateGrantLastUsedAt(ctx context.Context, userID, c SetLastUsedAt(time.Now()). Exec(ctx) } + +func (c *oauthClientClient) List(ctx context.Context, args *ListOAuthClientArgs) (*ListOAuthClientResult, error) { + query := c.client.OAuthClient.Query() + + if args.Name != "" { + query.Where(oauthclient.NameContains(args.Name)) + } + + if args.IsEnabled != nil { + query.Where(oauthclient.IsEnabled(*args.IsEnabled)) + } + + pageSize := capPageSize(c.maxSQlParam, args.PageSize, 1) + + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to count OAuth clients: %w", err) + } + + query.Order(getOAuthClientOrderOption(args)...) + + clients, err := query. + Limit(pageSize). + Offset(args.Page * pageSize). + All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list OAuth clients: %w", err) + } + + return &ListOAuthClientResult{ + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: pageSize, + }, + Clients: clients, + }, nil +} + +func (c *oauthClientClient) GetByID(ctx context.Context, id int) (*ent.OAuthClient, error) { + return c.client.OAuthClient.Query(). + Where(oauthclient.ID(id)). + First(ctx) +} + +func (c *oauthClientClient) Create(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) { + if client.Props == nil { + client.Props = &types.OAuthClientProps{} + } + + // Generate a new GUID and secret if not provided + if client.GUID == "" { + client.GUID = uuid.Must(uuid.NewV4()).String() + } + if client.Secret == "" { + client.Secret = util.RandStringRunes(32) + } + + return c.client.OAuthClient.Create(). + SetGUID(client.GUID). + SetSecret(client.Secret). + SetName(client.Name). + SetHomepageURL(client.HomepageURL). + SetRedirectUris(client.RedirectUris). + SetScopes(client.Scopes). + SetProps(client.Props). + SetIsEnabled(client.IsEnabled). + Save(ctx) +} + +func (c *oauthClientClient) Update(ctx context.Context, client *ent.OAuthClient) (*ent.OAuthClient, error) { + if client.Props == nil { + client.Props = &types.OAuthClientProps{} + } + + update := c.client.OAuthClient.UpdateOneID(client.ID). + SetName(client.Name). + SetHomepageURL(client.HomepageURL). + SetRedirectUris(client.RedirectUris). + SetScopes(client.Scopes). + SetProps(client.Props). + SetIsEnabled(client.IsEnabled) + + // Only update secret if provided (non-empty) + if client.Secret != "" { + update.SetSecret(client.Secret) + } + + return update.Save(ctx) +} + +func (c *oauthClientClient) Delete(ctx context.Context, id int) error { + // Delete all grants first + _, err := c.client.OAuthGrant.Delete(). + Where(oauthgrant.ClientID(id)). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to delete OAuth grants: %w", err) + } + + // Delete the client + return c.client.OAuthClient.DeleteOneID(id).Exec(ctx) +} + +func (c *oauthClientClient) CountGrants(ctx context.Context, id int) (int, error) { + return c.client.OAuthGrant.Query(). + Where(oauthgrant.ClientID(id)). + Count(ctx) +} + +func getOAuthClientOrderOption(args *ListOAuthClientArgs) []oauthclient.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case oauthclient.FieldName: + return []oauthclient.OrderOption{oauthclient.ByName(orderTerm), oauthclient.ByID(orderTerm)} + case oauthclient.FieldCreatedAt: + return []oauthclient.OrderOption{oauthclient.ByCreatedAt(orderTerm), oauthclient.ByID(orderTerm)} + case oauthclient.FieldIsEnabled: + return []oauthclient.OrderOption{oauthclient.ByIsEnabled(orderTerm), oauthclient.ByID(orderTerm)} + default: + return []oauthclient.OrderOption{oauthclient.ByID(orderTerm)} + } +} diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index 12e7ab50..1791f78e 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -597,3 +597,69 @@ func AdminCalibrateStorage(c *gin.Context) { } c.JSON(200, serializer.Response{Data: res}) } + +// AdminListOAuthClients lists OAuth clients +func AdminListOAuthClients(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.OAuthClients(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +// AdminGetOAuthClient gets an OAuth client by ID +func AdminGetOAuthClient(c *gin.Context) { + service := ParametersFromContext[*admin.SingleOAuthClientService](c, admin.SingleOAuthClientParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +// AdminCreateOAuthClient creates a new OAuth client +func AdminCreateOAuthClient(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertOAuthClientService](c, admin.UpsertOAuthClientParamCtx{}) + res, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +// AdminUpdateOAuthClient updates an OAuth client +func AdminUpdateOAuthClient(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertOAuthClientService](c, admin.UpsertOAuthClientParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +// AdminDeleteOAuthClient deletes an OAuth client +func AdminDeleteOAuthClient(c *gin.Context) { + service := ParametersFromContext[*admin.SingleOAuthClientService](c, admin.SingleOAuthClientParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{}) +} + +// AdminBatchDeleteOAuthClient batch deletes OAuth clients +func AdminBatchDeleteOAuthClient(c *gin.Context) { + service := ParametersFromContext[*admin.BatchOAuthClientService](c, admin.BatchOAuthClientParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{}) +} diff --git a/routers/router.go b/routers/router.go index 54e4bd00..88499cdc 100644 --- a/routers/router.go +++ b/routers/router.go @@ -334,6 +334,7 @@ func initMasterRouter(dep dependency.Dep) *gin.Engine { controllers.ExchangeToken, ) oauthRouter.GET("userinfo", + middleware.LoginRequired(), controllers.FromQuery[oauth.UserInfoService](oauth.UserInfoParamCtx{}), controllers.OpenIDUserInfo, ) @@ -1038,6 +1039,44 @@ func initMasterRouter(dep dependency.Dep) *gin.Engine { ) } + oauthClient := admin.Group("oauthClient") + { + // List OAuth clients + oauthClient.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListOAuthClients, + ) + // Get OAuth client + oauthClient.GET(":id", + controllers.FromUri[adminsvc.SingleOAuthClientService](adminsvc.SingleOAuthClientParamCtx{}), + controllers.AdminGetOAuthClient, + ) + // Create OAuth client + oauthClient.PUT("", + middleware.RequiredScopes(types.ScopeAdminWrite), + controllers.FromJSON[adminsvc.UpsertOAuthClientService](adminsvc.UpsertOAuthClientParamCtx{}), + controllers.AdminCreateOAuthClient, + ) + // Update OAuth client + oauthClient.PUT(":id", + middleware.RequiredScopes(types.ScopeAdminWrite), + controllers.FromJSON[adminsvc.UpsertOAuthClientService](adminsvc.UpsertOAuthClientParamCtx{}), + controllers.AdminUpdateOAuthClient, + ) + // Delete OAuth client + oauthClient.DELETE(":id", + middleware.RequiredScopes(types.ScopeAdminWrite), + controllers.FromUri[adminsvc.SingleOAuthClientService](adminsvc.SingleOAuthClientParamCtx{}), + controllers.AdminDeleteOAuthClient, + ) + // Batch delete OAuth clients + oauthClient.POST("batch/delete", + middleware.RequiredScopes(types.ScopeAdminWrite), + controllers.FromJSON[adminsvc.BatchOAuthClientService](adminsvc.BatchOAuthClientParamCtx{}), + controllers.AdminBatchDeleteOAuthClient, + ) + } + user := admin.Group("user") { // 列出用户 diff --git a/service/admin/oauth_client.go b/service/admin/oauth_client.go new file mode 100644 index 00000000..cbcbf864 --- /dev/null +++ b/service/admin/oauth_client.go @@ -0,0 +1,191 @@ +package admin + +import ( + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" + "github.com/samber/lo" +) + +// System OAuth client GUIDs that cannot be deleted +var systemOAuthClientGUIDs = []string{ + inventory.OAuthClientDesktopGUID, + inventory.OAuthClientiOSGUID, +} + +type ( + SingleOAuthClientService struct { + ID int `uri:"id" json:"id" binding:"required"` + } + SingleOAuthClientParamCtx struct{} +) + +type ( + UpsertOAuthClientService struct { + Client *ent.OAuthClient `json:"client" binding:"required"` + } + UpsertOAuthClientParamCtx struct{} +) + +type ( + BatchOAuthClientService struct { + IDs []int `json:"ids" binding:"required"` + } + BatchOAuthClientParamCtx struct{} +) + +// OAuthClients lists OAuth clients with pagination +func (s *AdminListService) OAuthClients(c *gin.Context) (*ListOAuthClientResponse, error) { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + var isEnabled *bool + if enabledStr, ok := s.Conditions["is_enabled"]; ok { + enabled := enabledStr == "true" + isEnabled = &enabled + } + + res, err := oauthClient.List(c, &inventory.ListOAuthClientArgs{ + PaginationArgs: &inventory.PaginationArgs{ + Page: s.Page - 1, + PageSize: s.PageSize, + OrderBy: s.OrderBy, + Order: inventory.OrderDirection(s.OrderDirection), + }, + Name: s.Searches["name"], + IsEnabled: isEnabled, + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list OAuth clients", err) + } + + clients := lo.Map(res.Clients, func(client *ent.OAuthClient, _ int) GetOAuthClientResponse { + return GetOAuthClientResponse{ + OAuthClient: client, + IsSystem: lo.Contains(systemOAuthClientGUIDs, client.GUID), + } + }) + + return &ListOAuthClientResponse{ + Pagination: res.PaginationResults, + Clients: clients, + }, nil +} + +func (s *SingleOAuthClientService) Get(c *gin.Context) (*GetOAuthClientResponse, error) { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + client, err := oauthClient.GetByID(c, s.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "OAuth client not found", err) + } + + res := &GetOAuthClientResponse{ + OAuthClient: client, + IsSystem: lo.Contains(systemOAuthClientGUIDs, client.GUID), + } + + // Count grants + grants, err := oauthClient.CountGrants(c, s.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to count grants", err) + } + res.TotalGrants = grants + + return res, nil +} + +func (s *UpsertOAuthClientService) Create(c *gin.Context) (*GetOAuthClientResponse, error) { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + if s.Client.ID > 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "ID must be 0 for creating new OAuth client", nil) + } + + client, err := oauthClient.Create(c, s.Client) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create OAuth client", err) + } + + service := &SingleOAuthClientService{ID: client.ID} + return service.Get(c) +} + +func (s *UpsertOAuthClientService) Update(c *gin.Context) (*GetOAuthClientResponse, error) { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + if s.Client.ID == 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "ID is required", nil) + } + + // Check if this is a system client + existing, err := oauthClient.GetByID(c, s.Client.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "OAuth client not found", err) + } + + // System clients cannot change GUID + if lo.Contains(systemOAuthClientGUIDs, existing.GUID) { + s.Client.GUID = existing.GUID + } + + _, err = oauthClient.Update(c, s.Client) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update OAuth client", err) + } + + service := &SingleOAuthClientService{ID: s.Client.ID} + return service.Get(c) +} + +func (s *SingleOAuthClientService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + // Check if client exists + client, err := oauthClient.GetByID(c, s.ID) + if err != nil { + return serializer.NewError(serializer.CodeNotFound, "OAuth client not found", err) + } + + // Check if this is a system client + if lo.Contains(systemOAuthClientGUIDs, client.GUID) { + return serializer.NewError(serializer.CodeInvalidActionOnSystemGroup, "Cannot delete system OAuth client", nil) + } + + err = oauthClient.Delete(c, s.ID) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete OAuth client", err) + } + + return nil +} + +func (s *BatchOAuthClientService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + oauthClient := dep.OAuthClientClient() + + for _, id := range s.IDs { + // Check if client exists + client, err := oauthClient.GetByID(c, id) + if err != nil { + continue // Skip non-existent clients + } + + // Check if this is a system client + if lo.Contains(systemOAuthClientGUIDs, client.GUID) { + continue // Skip system clients + } + + // Delete the client (including grants) + oauthClient.Delete(c, id) + } + + return nil +} diff --git a/service/admin/response.go b/service/admin/response.go index 49afbe0a..fbfb3358 100644 --- a/service/admin/response.go +++ b/service/admin/response.go @@ -114,6 +114,17 @@ type ListGroupResponse struct { Pagination *inventory.PaginationResults `json:"pagination"` } +type ListOAuthClientResponse struct { + Clients []GetOAuthClientResponse `json:"clients"` + Pagination *inventory.PaginationResults `json:"pagination"` +} + +type GetOAuthClientResponse struct { + *ent.OAuthClient + IsSystem bool `json:"is_system"` + TotalGrants int `json:"total_grants,omitempty"` +} + type HomepageSummary struct { MetricsSummary *MetricsSummary `json:"metrics_summary"` SiteURls []string `json:"site_urls"`