package inventory import ( "context" "fmt" "time" "entgo.io/ent/dialect/sql" "github.com/cloudreve/Cloudreve/v4/ent" "github.com/cloudreve/Cloudreve/v4/ent/task" "github.com/cloudreve/Cloudreve/v4/inventory/types" "github.com/cloudreve/Cloudreve/v4/pkg/conf" "github.com/cloudreve/Cloudreve/v4/pkg/hashid" "github.com/gofrs/uuid" "github.com/samber/lo" ) type ( // Ctx keys for eager loading options. LoadTaskUser struct{} TaskArgs struct { Status task.Status Type string PublicState *types.TaskPublicState PrivateState string OwnerID int CorrelationID uuid.UUID } ) type TaskClient interface { TxOperator // New creates a new task with the given args. New(ctx context.Context, task *TaskArgs) (*ent.Task, error) // Update updates the task with the given args. Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) // GetPendingTasks returns all pending tasks of given type. GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) // GetTaskByID returns the task with the given ID. GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) // SetCompleteByID sets the task with the given ID to complete. SetCompleteByID(ctx context.Context, taskID int) error // List returns a list of tasks with the given args. List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) // DeleteByIDs deletes the tasks with the given IDs. DeleteByIDs(ctx context.Context, ids ...int) error // DeleteBy deletes the tasks with the given args. DeleteBy(ctx context.Context, args *DeleteTaskArgs) error } type ( ListTaskArgs struct { *PaginationArgs Types []string Status []task.Status UserID int CorrelationID *uuid.UUID } ListTaskResult struct { *PaginationResults Tasks []*ent.Task } DeleteTaskArgs struct { NotAfter time.Time Types []string Status []task.Status } ) func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient { return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher} } type taskClient struct { maxSQlParam int hasher hashid.Encoder client *ent.Client } func (c *taskClient) SetClient(newClient *ent.Client) TxOperator { return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher} } func (c *taskClient) GetClient() *ent.Client { return c.client } func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) { stm := c.client.Task. Create(). SetType(task.Type). SetPublicState(task.PublicState) if task.PrivateState != "" { stm.SetPrivateState(task.PrivateState) } if task.OwnerID != 0 { stm.SetUserID(task.OwnerID) } if task.Status != "" { stm.SetStatus(task.Status) } if task.CorrelationID.String() != uuid.Nil.String() { stm.SetCorrelationID(task.CorrelationID) } newTask, err := stm.Save(ctx) if err != nil { return nil, fmt.Errorf("failed to create task: %w", err) } return newTask, nil } func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error { _, err := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx) return err } func (c *taskClient) DeleteBy(ctx context.Context, args *DeleteTaskArgs) error { query := c.client.Task. Delete(). Where(task.CreatedAtLTE(args.NotAfter)) if len(args.Status) > 0 { query.Where(task.StatusIn(args.Status...)) } if len(args.Types) > 0 { query.Where(task.TypeIn(args.Types...)) } _, err := query.Exec(ctx) return err } func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) { stm := c.client.Task.UpdateOne(task). SetPublicState(args.PublicState) task.PublicState = args.PublicState if task.PrivateState != "" { stm.SetPrivateState(task.PrivateState) task.PrivateState = args.PrivateState } if task.Status != "" { stm.SetStatus(args.Status) task.Status = args.Status } if err := stm.Exec(ctx); err != nil { return nil, fmt.Errorf("failed to create task: %w", err) } return task, nil } func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) { tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()). Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)). Where(task.TypeIn(taskType...)). All(ctx) if err != nil { return nil, err } // Anonymous user is not loaded by default, so we need to load it manually. userClient := NewUserClient(c.client) anonymous, err := userClient.AnonymousUser(ctx) for _, t := range tasks { if t.UserTasks == 0 { if err != nil { return nil, err } t.SetUser(anonymous) } } return tasks, nil } func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) { return withTaskEagerLoading(ctx, c.client.Task.Query()). Where(task.ID(taskID)). First(ctx) } func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error { _, err := c.client.Task.UpdateOneID(taskID). SetStatus(task.StatusCompleted). Save(ctx) return err } func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) { q := c.client.Task.Query() if args.UserID != 0 { q.Where(task.UserTasks(args.UserID)) } if args.Types != nil { q.Where(task.TypeIn(args.Types...)) } if args.Status != nil { q.Where(task.StatusIn(args.Status...)) } if args.CorrelationID != nil { q.Where(task.CorrelationID(*args.CorrelationID)) } q = withTaskEagerLoading(ctx, q) var ( tasks []*ent.Task err error paginationRes *PaginationResults ) if args.UseCursorPagination { tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1) } else { tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1) } if err != nil { return nil, fmt.Errorf("query failed with paginiation: %w", err) } return &ListTaskResult{ Tasks: tasks, PaginationResults: paginationRes, }, nil } func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) { pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin) query.Order(task.ByID(sql.OrderDesc())) var ( pageToken *PageToken err error queryPaged = query ) if args.PageToken != "" { pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID) if err != nil { return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err) } queryPaged = query.Where(task.IDLT(pageToken.ID)) } // Use page size + 1 to determine if there are more items to come queryPaged.Limit(pageSize + 1) tasks, err := queryPaged. All(ctx) if err != nil { return nil, nil, err } // More items to come nextTokenStr := "" if len(tasks) > pageSize { lastItem := tasks[len(tasks)-2] nextToken, err := getTaskNextPageToken(c.hasher, lastItem) if err != nil { return nil, nil, fmt.Errorf("failed to generate next page token: %w", err) } nextTokenStr = nextToken } return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{ PageSize: pageSize, NextPageToken: nextTokenStr, IsCursor: true, }, nil } func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) { pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin) query.Order(getTaskOrderOption(args)...) // Count total items total, err := query.Clone().Count(ctx) if err != nil { return nil, nil, err } logs, err := query. Limit(pageSize). Offset(args.Page * args.PageSize). All(ctx) if err != nil { return nil, nil, err } return logs, &PaginationResults{ PageSize: pageSize, TotalItems: total, Page: args.Page, }, nil } func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption { orderTerm := getOrderTerm(args.Order) switch args.OrderBy { default: return []task.OrderOption{task.ByID(orderTerm)} } } // getTaskNextPageToken returns the next page token for the given last task. func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) { token := &PageToken{ ID: last.ID, } return token.Encode(hasher, hashid.EncodeTaskID) } func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery { if v, ok := ctx.Value(LoadTaskUser{}).(bool); ok && v { q.WithUser(func(q *ent.UserQuery) { withUserEagerLoading(ctx, q) }) } return q }