diff --git a/.build/aria2.supervisor.conf b/.build/aria2.supervisor.conf new file mode 100644 index 00000000..f7d3bca7 --- /dev/null +++ b/.build/aria2.supervisor.conf @@ -0,0 +1,7 @@ +[supervisord] +nodaemon=false + +[program:background_process] +command=aria2c --enable-rpc --save-session /cloudreve/data +autostart=true +autorestart=true \ No newline at end of file diff --git a/.build/build-assets.sh b/.build/build-assets.sh new file mode 100755 index 00000000..fc35c72a --- /dev/null +++ b/.build/build-assets.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -e +export NODE_OPTIONS="--max-old-space-size=8192" + +# This script is used to build the assets for the application. +cd assets +rm -rf build +yarn install --network-timeout 1000000 +yarn version --new-version $1 --no-git-tag-version +yarn run build + +# Copy the build files to the application directory +cd ../ +zip -r - assets/build >assets.zip +mv assets.zip application/statics \ No newline at end of file diff --git a/.build/entrypoint.sh b/.build/entrypoint.sh new file mode 100755 index 00000000..ba4875f6 --- /dev/null +++ b/.build/entrypoint.sh @@ -0,0 +1,2 @@ +supervisord -c ./aria2.supervisor.conf +./cloudreve \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index 21a187c9..00000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Build - -on: workflow_dispatch - -jobs: - build: - name: Build - runs-on: ubuntu-latest - steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - with: - clean: false - submodules: "recursive" - - run: | - git fetch --prune --unshallow --tags - - - name: Build and Release - uses: goreleaser/goreleaser-action@v4 - with: - distribution: goreleaser - version: latest - args: release --clean --skip-validate - env: - GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml deleted file mode 100644 index 6788ea16..00000000 --- a/.github/workflows/docker-release.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Build and push docker image - -on: - push: - tags: - - 3.* # triggered on every push with tag 3.* - workflow_dispatch: # or just on button clicked - -jobs: - docker-build: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v2 - - run: git fetch --prune --unshallow - - name: Setup Environments - id: envs - run: | - CLOUDREVE_LATEST_TAG=$(git describe --tags --abbrev=0) - DOCKER_IMAGE="cloudreve/cloudreve" - - echo "RELEASE_VERSION=${GITHUB_REF#refs}" - TAGS="${DOCKER_IMAGE}:latest,${DOCKER_IMAGE}:${CLOUDREVE_LATEST_TAG}" - - echo "CLOUDREVE_LATEST_TAG:${CLOUDREVE_LATEST_TAG}" - echo ::set-output name=tags::${TAGS} - - name: Setup QEMU Emulator - uses: docker/setup-qemu-action@master - with: - platforms: all - - name: Setup Docker Buildx Command - id: buildx - uses: docker/setup-buildx-action@master - - name: Login to Dockerhub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - - name: Build Docker Image and Push - id: docker_build - uses: docker/build-push-action@v2 - with: - push: true - builder: ${{ steps.buildx.outputs.name }} - context: . 
- file: ./Dockerfile - platforms: linux/amd64,linux/arm64,linux/arm/v7 - tags: ${{ steps.envs.outputs.tags }} - - name: Update Docker Hub Description - uses: peter-evans/dockerhub-description@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - repository: cloudreve/cloudreve - short-description: ${{ github.event.repository.description }} - - name: Image Digest - run: echo ${{ steps.docker_build.outputs.digest }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index 17a6ecf6..00000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Test - -on: - pull_request: - branches: - - master - push: - branches: [master] - -jobs: - test: - name: Test - runs-on: ubuntu-latest - steps: - - name: Set up Go 1.20 - uses: actions/setup-go@v2 - with: - go-version: "1.20" - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - with: - submodules: "recursive" - - - name: Build static files - run: | - mkdir assets/build - touch assets/build/test.html - - - name: Test - run: go test -coverprofile=coverage.txt -covermode=atomic ./... - - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v2 diff --git a/.gitignore b/.gitignore index e99f29b0..3b2e45c2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ # Binaries for programs and plugins -cloudreve *.exe *.exe~ *.dll @@ -8,7 +7,7 @@ cloudreve *.db *.bin /release/ -assets.zip +application/statics/assets.zip # Test binary, build with `go test -c` *.test @@ -31,3 +30,5 @@ conf/conf.ini .vscode/ dist/ +data/ +tmp/ \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 8631ace5..dc118267 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,19 +1,20 @@ -env: - - CI=false - - GENERATE_SOURCEMAP=false +version: 2 + before: hooks: - go mod tidy - - sh -c "cd assets && rm -rf build && yarn install --network-timeout 1000000 && yarn run build && cd ../ && zip -r - assets/build >assets.zip" + - chmod +x ./.build/build-assets.sh + - ./.build/build-assets.sh {{.Version}} + builds: - - - env: + - env: - CGO_ENABLED=0 binary: cloudreve ldflags: - - -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion={{.Tag}}' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit={{.ShortCommit}}' + - -s -w + - -X 'github.com/cloudreve/Cloudreve/v4/application/constants.BackendVersion={{.Tag}}' -X 'github.com/cloudreve/Cloudreve/v4/application/constants.LastCommit={{.ShortCommit}}' goos: - linux @@ -39,83 +40,73 @@ builds: goarm: 7 archives: - - format: tar.gz + - formats: ["tar.gz"] # this name template makes the OS and Arch compatible with the results of uname. 
name_template: >- cloudreve_{{.Tag}}_{{- .Os }}_{{ .Arch }} {{- if .Arm }}v{{ .Arm }}{{ end }} # use zip for windows archives format_overrides: - - goos: windows - format: zip + - goos: windows + formats: ["zip"] + checksum: - name_template: 'checksums.txt' + name_template: "checksums.txt" snapshot: - name_template: "{{ incpatch .Version }}-next" + version_template: "{{ incpatch .Version }}-next" + changelog: sort: asc filters: exclude: - - '^docs:' - - '^test:' + - "^docs:" + - "^test:" release: draft: true prerelease: auto - target_commitish: '{{ .Commit }}' + skip_upload: true + target_commitish: "{{ .Commit }}" name_template: "{{.Version}}" +blobs: + - provider: s3 + endpoint: https://a09fb0452382d8d745cf79d9c5ce7f7d.r2.cloudflarestorage.com + region: auto + bucket: cloudreve + directory: "{{.Version}}" + dockers: - - - dockerfile: Dockerfile + - dockerfile: Dockerfile use: buildx build_flag_templates: - "--platform=linux/amd64" goos: linux goarch: amd64 goamd64: v1 + extra_files: + - .build/aria2.supervisor.conf + - .build/entrypoint.sh image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-amd64" - - - dockerfile: Dockerfile + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-amd64" + - dockerfile: Dockerfile use: buildx build_flag_templates: - "--platform=linux/arm64" goos: linux goarch: arm64 + extra_files: + - .build/aria2.supervisor.conf + - .build/entrypoint.sh image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-arm64" - - - dockerfile: Dockerfile - use: buildx - build_flag_templates: - - "--platform=linux/arm/v6" - goos: linux - goarch: arm - goarm: '6' - image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-armv6" - - - dockerfile: Dockerfile - use: buildx - build_flag_templates: - - "--platform=linux/arm/v7" - goos: linux - goarch: arm - goarm: '7' - image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-armv7" + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-arm64" docker_manifests: - - name_template: "cloudreve/cloudreve:latest" + - name_template: "cloudreve.azurecr.io/cloudreve/pro:latest" image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-amd64" - - "cloudreve/cloudreve:{{ .Tag }}-arm64" - - "cloudreve/cloudreve:{{ .Tag }}-armv6" - - "cloudreve/cloudreve:{{ .Tag }}-armv7" - - name_template: "cloudreve/cloudreve:{{ .Tag }}" + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-amd64" + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-arm64" + - name_template: "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}" image_templates: - - "cloudreve/cloudreve:{{ .Tag }}-amd64" - - "cloudreve/cloudreve:{{ .Tag }}-arm64" - - "cloudreve/cloudreve:{{ .Tag }}-armv6" - - "cloudreve/cloudreve:{{ .Tag }}-armv7" \ No newline at end of file + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-amd64" + - "cloudreve.azurecr.io/cloudreve/pro:{{ .Tag }}-arm64" diff --git a/Dockerfile b/Dockerfile index 7b2f5ae9..d32d33a3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,28 @@ FROM alpine:latest WORKDIR /cloudreve + +COPY .build/aria2.supervisor.conf .build/entrypoint.sh ./ COPY cloudreve ./cloudreve RUN apk update \ - && apk add --no-cache tzdata \ + && apk add --no-cache tzdata vips-tools ffmpeg libreoffice aria2 supervisor font-noto font-noto-cjk \ && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && echo "Asia/Shanghai" > /etc/timezone \ && chmod +x ./cloudreve \ - && mkdir -p /data/aria2 \ - && chmod -R 766 /data/aria2 + && chmod +x ./entrypoint.sh \ + && mkdir -p ./data/temp/aria2 \ + && chmod -R 766 ./data/temp/aria2 + +ENV CR_ENABLE_ARIA2=1 \ + CR_SETTING_DEFAULT_thumb_ffmpeg_enabled=1 \ 
+    CR_SETTING_DEFAULT_thumb_vips_enabled=1 \
+    CR_SETTING_DEFAULT_thumb_libreoffice_enabled=1 \
+    CR_SETTING_DEFAULT_media_meta_ffprobe=1
+
+EXPOSE 5212 443
+
+VOLUME ["/cloudreve/data"]
-EXPOSE 5212
-VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"]
+ENTRYPOINT ["sh", "./entrypoint.sh"]
-ENTRYPOINT ["./cloudreve"]
diff --git a/application/application.go b/application/application.go
new file mode 100644
index 00000000..d0f0299d
--- /dev/null
+++ b/application/application.go
@@ -0,0 +1,219 @@
+package application
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "net"
+    "net/http"
+    "os"
+    "time"
+
+    "github.com/cloudreve/Cloudreve/v4/application/constants"
+    "github.com/cloudreve/Cloudreve/v4/application/dependency"
+    "github.com/cloudreve/Cloudreve/v4/ent"
+    "github.com/cloudreve/Cloudreve/v4/pkg/cache"
+    "github.com/cloudreve/Cloudreve/v4/pkg/conf"
+    "github.com/cloudreve/Cloudreve/v4/pkg/crontab"
+    "github.com/cloudreve/Cloudreve/v4/pkg/email"
+    "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/onedrive"
+    "github.com/cloudreve/Cloudreve/v4/pkg/logging"
+    "github.com/cloudreve/Cloudreve/v4/pkg/setting"
+    "github.com/cloudreve/Cloudreve/v4/pkg/util"
+    "github.com/cloudreve/Cloudreve/v4/routers"
+    "github.com/gin-gonic/gin"
+)
+
+type Server interface {
+    // Start starts the Cloudreve server.
+    Start() error
+    PrintBanner()
+    Close()
+}
+
+// NewServer constructs a new Cloudreve server instance with given dependency.
+func NewServer(dep dependency.Dep) Server {
+    return &server{
+        dep:    dep,
+        logger: dep.Logger(),
+        config: dep.ConfigProvider(),
+    }
+}
+
+type server struct {
+    dep       dependency.Dep
+    logger    logging.Logger
+    dbClient  *ent.Client
+    config    conf.ConfigProvider
+    server    *http.Server
+    kv        cache.Driver
+    mailQueue email.Driver
+}
+
+func (s *server) PrintBanner() {
+    fmt.Print(`
+   ___ _                 _
+  / __\ | ___  _   _  __| |_ __ _____   _____
+ / /  | |/ _ \| | | |/ _  | '__/ _ \ \ / / _ \
+/ /___| | (_) | |_| | (_| | | | __/\ V /  __/
+\____/|_|\___/ \__,_|\__,_|_|  \___| \_/ \___|
+
+  V` + constants.BackendVersion + ` Commit #` + constants.LastCommit + ` Pro=` + constants.IsPro + `
+================================================
+
+`)
+}
+
+func (s *server) Start() error {
+    // Switch to release mode when debug is disabled
+    if !s.config.System().Debug {
+        gin.SetMode(gin.ReleaseMode)
+    }
+
+    s.kv = s.dep.KV()
+    // delete all cached settings
+    _ = s.kv.Delete(setting.KvSettingPrefix)
+
+    // TODO: make sure redis is connected in dep before user traffic.
+    if s.config.System().Mode == conf.MasterMode {
+        s.dbClient = s.dep.DBClient()
+        // TODO: make sure all dep is initialized before server start.
+        s.dep.LockSystem()
+        s.dep.UAParser()
+
+        // Initialize OneDrive credentials
+        credentials, err := onedrive.RetrieveOneDriveCredentials(context.Background(), s.dep.StoragePolicyClient())
+        if err != nil {
+            return fmt.Errorf("failed to retrieve OneDrive credentials for CredManager: %w", err)
+        }
+        if err := s.dep.CredManager().Upsert(context.Background(), credentials...); err != nil {
+            return fmt.Errorf("failed to upsert OneDrive credentials to CredManager: %w", err)
+        }
+        crontab.Register(setting.CronTypeOauthCredRefresh, func(ctx context.Context) {
+            dep := dependency.FromContext(ctx)
+            cred := dep.CredManager()
+            cred.RefreshAll(ctx)
+        })
+
+        // Initialize email queue before user traffic starts.
+ _ = s.dep.EmailClient(context.Background()) + + // Start all queues + s.dep.MediaMetaQueue(context.Background()).Start() + s.dep.EntityRecycleQueue(context.Background()).Start() + s.dep.IoIntenseQueue(context.Background()).Start() + s.dep.RemoteDownloadQueue(context.Background()).Start() + + // Start cron jobs + c, err := crontab.NewCron(context.Background(), s.dep) + if err != nil { + return err + } + c.Start() + + // Start node pool + if _, err := s.dep.NodePool(context.Background()); err != nil { + return err + } + } else { + s.dep.SlaveQueue(context.Background()).Start() + } + s.dep.ThumbQueue(context.Background()).Start() + + api := routers.InitRouter(s.dep) + api.TrustedPlatform = s.config.System().ProxyHeader + s.server = &http.Server{Handler: api} + + // 如果启用了SSL + if s.config.SSL().CertPath != "" { + s.logger.Info("Listening to %q", s.config.SSL().Listen) + s.server.Addr = s.config.SSL().Listen + if err := s.server.ListenAndServeTLS(s.config.SSL().CertPath, s.config.SSL().KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to listen to %q: %w", s.config.SSL().Listen, err) + } + + return nil + } + + // 如果启用了Unix + if s.config.Unix().Listen != "" { + // delete socket file before listening + if _, err := os.Stat(s.config.Unix().Listen); err == nil { + if err = os.Remove(s.config.Unix().Listen); err != nil { + return fmt.Errorf("failed to delete socket file %q: %w", s.config.Unix().Listen, err) + } + } + + s.logger.Info("Listening to %q", s.config.Unix().Listen) + if err := s.runUnix(s.server); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to listen to %q: %w", s.config.Unix().Listen, err) + } + + return nil + } + + s.logger.Info("Listening to %q", s.config.System().Listen) + s.server.Addr = s.config.System().Listen + if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to listen to %q: %w", s.config.System().Listen, err) + } + return nil +} + +func (s *server) Close() { + if s.dbClient != nil { + s.logger.Info("Shutting down database connection...") + if err := s.dbClient.Close(); err != nil { + s.logger.Error("Failed to close database connection: %s", err) + } + } + + ctx := context.Background() + if conf.SystemConfig.GracePeriod != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(s.config.System().GracePeriod)*time.Second) + defer cancel() + } + + // Shutdown http server + if s.server != nil { + err := s.server.Shutdown(ctx) + if err != nil { + s.logger.Error("Failed to shutdown server: %s", err) + } + } + + if s.kv != nil { + if err := s.kv.Persist(util.DataPath(cache.DefaultCacheFile)); err != nil { + s.logger.Warning("Failed to persist cache: %s", err) + } + } + + if err := s.dep.Shutdown(ctx); err != nil { + s.logger.Warning("Failed to shutdown dependency manager: %s", err) + } +} + +func (s *server) runUnix(server *http.Server) error { + listener, err := net.Listen("unix", s.config.Unix().Listen) + if err != nil { + return err + } + + defer listener.Close() + defer os.Remove(s.config.Unix().Listen) + + if conf.UnixConfig.Perm > 0 { + err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(s.config.Unix().Perm)) + if err != nil { + s.logger.Warning( + "Failed to set permission to %q for socket file %q: %s", + s.config.Unix().Perm, + s.config.Unix().Listen, + err, + ) + } + } + + return server.Serve(listener) +} diff --git a/application/constants/constants.go b/application/constants/constants.go new 
file mode 100644 index 00000000..81b96b5f --- /dev/null +++ b/application/constants/constants.go @@ -0,0 +1,34 @@ +package constants + +// These values will be injected at build time, DO NOT EDIT. + +// BackendVersion 当前后端版本号 +var BackendVersion = "4.0.0-alpha.1" + +// IsPro 是否为Pro版本 +var IsPro = "false" + +var IsProBool = IsPro == "true" + +// LastCommit 最后commit id +var LastCommit = "000000" + +const ( + APIPrefix = "/api/v4" + APIPrefixSlave = "/api/v4/slave" + CrHeaderPrefix = "X-Cr-" +) + +const CloudreveScheme = "cloudreve" + +type ( + FileSystemType string +) + +const ( + FileSystemMy = FileSystemType("my") + FileSystemShare = FileSystemType("share") + FileSystemTrash = FileSystemType("trash") + FileSystemSharedWithMe = FileSystemType("shared_with_me") + FileSystemUnknown = FileSystemType("unknown") +) diff --git a/application/constants/size.go b/application/constants/size.go new file mode 100644 index 00000000..6e3d21ee --- /dev/null +++ b/application/constants/size.go @@ -0,0 +1,8 @@ +package constants + +const ( + MB = 1 << 20 + GB = 1 << 30 + TB = 1 << 40 + PB = 1 << 50 +) diff --git a/application/dependency/dependency.go b/application/dependency/dependency.go new file mode 100644 index 00000000..76a6dc7a --- /dev/null +++ b/application/dependency/dependency.go @@ -0,0 +1,874 @@ +package dependency + +import ( + "context" + "errors" + iofs "io/fs" + "net/url" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/statics" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/email" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/thumb" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-contrib/static" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/robfig/cron/v3" + "github.com/samber/lo" + "github.com/ua-parser/uap-go/uaparser" +) + +var ( + ErrorConfigPathNotSet = errors.New("config path not set") +) + +type ( + // DepCtx defines keys for dependency manager + DepCtx struct{} + // ReloadCtx force reload new dependency + ReloadCtx struct{} +) + +// Dep manages all dependencies of the server application. The default implementation is not +// concurrent safe, so all inner deps should be initialized before any goroutine starts. +type Dep interface { + // ConfigProvider Get a singleton conf.ConfigProvider instance. + ConfigProvider() conf.ConfigProvider + // Logger Get a singleton logging.Logger instance. + Logger() logging.Logger + // Statics Get a singleton fs.FS instance for embedded static resources. + Statics() iofs.FS + // ServerStaticFS Get a singleton static.ServeFileSystem instance for serving static resources. + ServerStaticFS() static.ServeFileSystem + // DBClient Get a singleton ent.Client instance for database access. + DBClient() *ent.Client + // KV Get a singleton cache.Driver instance for KV store. 
+ KV() cache.Driver + // NavigatorStateKV Get a singleton cache.Driver instance for navigator state store. It forces use in-memory + // map instead of Redis to get better performance for complex nested linked list. + NavigatorStateKV() cache.Driver + // SettingClient Get a singleton inventory.SettingClient instance for access DB setting store. + SettingClient() inventory.SettingClient + // SettingProvider Get a singleton setting.Provider instance for access setting store in strong type. + SettingProvider() setting.Provider + // UserClient Creates a new inventory.UserClient instance for access DB user store. + UserClient() inventory.UserClient + // GroupClient Creates a new inventory.GroupClient instance for access DB group store. + GroupClient() inventory.GroupClient + // EmailClient Get a singleton email.Driver instance for sending emails. + EmailClient(ctx context.Context) email.Driver + // GeneralAuth Get a singleton auth.Auth instance for general authentication. + GeneralAuth() auth.Auth + // Shutdown the dependencies gracefully. + Shutdown(ctx context.Context) error + // FileClient Creates a new inventory.FileClient instance for access DB file store. + FileClient() inventory.FileClient + // NodeClient Creates a new inventory.NodeClient instance for access DB node store. + NodeClient() inventory.NodeClient + // DavAccountClient Creates a new inventory.DavAccountClient instance for access DB dav account store. + DavAccountClient() inventory.DavAccountClient + // DirectLinkClient Creates a new inventory.DirectLinkClient instance for access DB direct link store. + DirectLinkClient() inventory.DirectLinkClient + // HashIDEncoder Get a singleton hashid.Encoder instance for encoding/decoding hashids. + HashIDEncoder() hashid.Encoder + // TokenAuth Get a singleton auth.TokenAuth instance for token authentication. + TokenAuth() auth.TokenAuth + // LockSystem Get a singleton lock.LockSystem instance for file lock management. + LockSystem() lock.LockSystem + // ShareClient Creates a new inventory.ShareClient instance for access DB share store. + StoragePolicyClient() inventory.StoragePolicyClient + // RequestClient Creates a new request.Client instance for HTTP requests. + RequestClient(opts ...request.Option) request.Client + // ShareClient Creates a new inventory.ShareClient instance for access DB share store. + ShareClient() inventory.ShareClient + // TaskClient Creates a new inventory.TaskClient instance for access DB task store. + TaskClient() inventory.TaskClient + // ForkWithLogger create a shallow copy of dependency with a new correlated logger, used as per-request dep. + ForkWithLogger(ctx context.Context, l logging.Logger) context.Context + // MediaMetaQueue Get a singleton queue.Queue instance for media metadata processing. + MediaMetaQueue(ctx context.Context) queue.Queue + // SlaveQueue Get a singleton queue.Queue instance for slave tasks. + SlaveQueue(ctx context.Context) queue.Queue + // MediaMetaExtractor Get a singleton mediameta.Extractor instance for media metadata extraction. + MediaMetaExtractor(ctx context.Context) mediameta.Extractor + // ThumbPipeline Get a singleton thumb.Generator instance for chained thumbnail generation. + ThumbPipeline() thumb.Generator + // ThumbQueue Get a singleton queue.Queue instance for thumbnail generation. + ThumbQueue(ctx context.Context) queue.Queue + // EntityRecycleQueue Get a singleton queue.Queue instance for entity recycle. 
+ EntityRecycleQueue(ctx context.Context) queue.Queue + // MimeDetector Get a singleton fs.MimeDetector instance for MIME type detection. + MimeDetector(ctx context.Context) mime.MimeDetector + // CredManager Get a singleton credmanager.CredManager instance for credential management. + CredManager() credmanager.CredManager + // IoIntenseQueue Get a singleton queue.Queue instance for IO intense tasks. + IoIntenseQueue(ctx context.Context) queue.Queue + // RemoteDownloadQueue Get a singleton queue.Queue instance for remote download tasks. + RemoteDownloadQueue(ctx context.Context) queue.Queue + // NodePool Get a singleton cluster.NodePool instance for node pool management. + NodePool(ctx context.Context) (cluster.NodePool, error) + // TaskRegistry Get a singleton queue.TaskRegistry instance for task registration. + TaskRegistry() queue.TaskRegistry + // WebAuthn Get a singleton webauthn.WebAuthn instance for WebAuthn authentication. + WebAuthn(ctx context.Context) (*webauthn.WebAuthn, error) + // UAParser Get a singleton uaparser.Parser instance for user agent parsing. + UAParser() *uaparser.Parser +} + +type dependency struct { + configProvider conf.ConfigProvider + logger logging.Logger + statics iofs.FS + serverStaticFS static.ServeFileSystem + dbClient *ent.Client + rawEntClient *ent.Client + kv cache.Driver + navigatorStateKv cache.Driver + settingClient inventory.SettingClient + fileClient inventory.FileClient + shareClient inventory.ShareClient + settingProvider setting.Provider + userClient inventory.UserClient + groupClient inventory.GroupClient + storagePolicyClient inventory.StoragePolicyClient + taskClient inventory.TaskClient + nodeClient inventory.NodeClient + davAccountClient inventory.DavAccountClient + directLinkClient inventory.DirectLinkClient + emailClient email.Driver + generalAuth auth.Auth + hashidEncoder hashid.Encoder + tokenAuth auth.TokenAuth + lockSystem lock.LockSystem + requestClient request.Client + ioIntenseQueue queue.Queue + thumbQueue queue.Queue + mediaMetaQueue queue.Queue + entityRecycleQueue queue.Queue + slaveQueue queue.Queue + remoteDownloadQueue queue.Queue + ioIntenseQueueTask queue.Task + mediaMeta mediameta.Extractor + thumbPipeline thumb.Generator + mimeDetector mime.MimeDetector + credManager credmanager.CredManager + nodePool cluster.NodePool + taskRegistry queue.TaskRegistry + webauthn *webauthn.WebAuthn + parser *uaparser.Parser + cron *cron.Cron + + configPath string + isPro bool + requiredDbVersion string + licenseKey string + + // Protects inner deps that can be reloaded at runtime. + mu sync.Mutex +} + +// NewDependency creates a new Dep instance for construct dependencies. +func NewDependency(opts ...Option) Dep { + d := &dependency{} + for _, o := range opts { + o.apply(d) + } + + return d +} + +// FromContext retrieves a Dep instance from context. +func FromContext(ctx context.Context) Dep { + return ctx.Value(DepCtx{}).(Dep) +} + +func (d *dependency) RequestClient(opts ...request.Option) request.Client { + if d.requestClient != nil { + return d.requestClient + } + + return request.NewClient(d.ConfigProvider(), opts...) 
+} + +func (d *dependency) WebAuthn(ctx context.Context) (*webauthn.WebAuthn, error) { + if d.webauthn != nil { + return d.webauthn, nil + } + + settings := d.SettingProvider() + siteBasic := settings.SiteBasic(ctx) + wConfig := &webauthn.Config{ + RPDisplayName: siteBasic.Name, + RPID: settings.SiteURL(ctx).Hostname(), + RPOrigins: lo.Map(settings.AllSiteURLs(ctx), func(item *url.URL, index int) string { + item.Path = "" + return item.String() + }), // The origin URLs allowed for WebAuthn requests + } + + return webauthn.New(wConfig) +} + +func (d *dependency) UAParser() *uaparser.Parser { + if d.parser != nil { + return d.parser + } + + d.parser = uaparser.NewFromSaved() + return d.parser +} + +func (d *dependency) ConfigProvider() conf.ConfigProvider { + if d.configProvider != nil { + return d.configProvider + } + + if d.configPath == "" { + d.panicError(ErrorConfigPathNotSet) + } + + var err error + d.configProvider, err = conf.NewIniConfigProvider(d.configPath, logging.NewConsoleLogger(logging.LevelInformational)) + if err != nil { + d.panicError(err) + } + + return d.configProvider +} + +func (d *dependency) Logger() logging.Logger { + if d.logger != nil { + return d.logger + } + + config := d.ConfigProvider() + logLevel := logging.LogLevel(config.System().LogLevel) + if config.System().Debug { + logLevel = logging.LevelDebug + } + + d.logger = logging.NewConsoleLogger(logLevel) + d.logger.Info("Logger initialized with LogLevel=%q.", logLevel) + return d.logger +} + +func (d *dependency) Statics() iofs.FS { + if d.statics != nil { + return d.statics + } + + d.statics = statics.NewStaticFS(d.Logger()) + return d.statics +} + +func (d *dependency) ServerStaticFS() static.ServeFileSystem { + if d.serverStaticFS != nil { + return d.serverStaticFS + } + + sfs, err := statics.NewServerStaticFS(d.Logger(), d.Statics(), d.isPro) + if err != nil { + d.panicError(err) + } + + d.serverStaticFS = sfs + return d.serverStaticFS +} + +func (d *dependency) DBClient() *ent.Client { + if d.dbClient != nil { + return d.dbClient + } + + if d.rawEntClient == nil { + client, err := inventory.NewRawEntClient(d.Logger(), d.ConfigProvider()) + if err != nil { + d.panicError(err) + } + + d.rawEntClient = client + } + + client, err := inventory.InitializeDBClient(d.Logger(), d.rawEntClient, d.KV(), d.requiredDbVersion) + if err != nil { + d.panicError(err) + } + + d.dbClient = client + return d.dbClient +} + +func (d *dependency) KV() cache.Driver { + if d.kv != nil { + return d.kv + } + + config := d.ConfigProvider().Redis() + if config.Server != "" { + d.kv = cache.NewRedisStore( + d.Logger(), + 10, + config.Network, + config.Server, + config.User, + config.Password, + config.DB, + ) + } else { + d.kv = cache.NewMemoStore(util.DataPath(cache.DefaultCacheFile), d.Logger()) + } + + return d.kv +} + +func (d *dependency) NavigatorStateKV() cache.Driver { + if d.navigatorStateKv != nil { + return d.navigatorStateKv + } + d.navigatorStateKv = cache.NewMemoStore("", d.Logger()) + return d.navigatorStateKv +} + +func (d *dependency) SettingClient() inventory.SettingClient { + if d.settingClient != nil { + return d.settingClient + } + + d.settingClient = inventory.NewSettingClient(d.DBClient(), d.KV()) + return d.settingClient +} + +func (d *dependency) SettingProvider() setting.Provider { + if d.settingProvider != nil { + return d.settingProvider + } + + if d.ConfigProvider().System().Mode == conf.MasterMode { + // For master mode, setting value will be retrieved in order: + // Env overwrite -> KV Store -> DB 
Setting Store + d.settingProvider = setting.NewProvider( + setting.NewEnvOverrideStore( + setting.NewKvSettingStore(d.KV(), + setting.NewDbSettingStore(d.SettingClient(), nil), + ), + d.Logger(), + ), + ) + } else { + // For slave mode, setting value will be retrieved in order: + // Env overwrite -> Config file overwrites -> Setting defaults in DB schema + d.settingProvider = setting.NewProvider( + setting.NewEnvOverrideStore( + setting.NewConfSettingStore(d.ConfigProvider(), + setting.NewDbDefaultStore(nil), + ), + d.Logger(), + ), + ) + } + + return d.settingProvider +} + +func (d *dependency) UserClient() inventory.UserClient { + if d.userClient != nil { + return d.userClient + } + + return inventory.NewUserClient(d.DBClient()) +} + +func (d *dependency) GroupClient() inventory.GroupClient { + if d.groupClient != nil { + return d.groupClient + } + + return inventory.NewGroupClient(d.DBClient(), d.ConfigProvider().Database().Type, d.KV()) +} + +func (d *dependency) NodeClient() inventory.NodeClient { + if d.nodeClient != nil { + return d.nodeClient + } + + return inventory.NewNodeClient(d.DBClient()) +} + +func (d *dependency) NodePool(ctx context.Context) (cluster.NodePool, error) { + reload, _ := ctx.Value(ReloadCtx{}).(bool) + if d.nodePool != nil && !reload { + return d.nodePool, nil + } + + if d.ConfigProvider().System().Mode == conf.MasterMode { + np, err := cluster.NewNodePool(ctx, d.Logger(), d.ConfigProvider(), d.SettingProvider(), d.NodeClient()) + if err != nil { + return nil, err + } + + d.nodePool = np + } else { + d.nodePool = cluster.NewSlaveDummyNodePool(ctx, d.ConfigProvider(), d.SettingProvider()) + } + + return d.nodePool, nil +} + +func (d *dependency) EmailClient(ctx context.Context) email.Driver { + d.mu.Lock() + defer d.mu.Unlock() + + if reload, _ := ctx.Value(ReloadCtx{}).(bool); reload || d.emailClient == nil { + if d.emailClient != nil { + d.emailClient.Close() + } + d.emailClient = email.NewSMTPPool(d.SettingProvider(), d.Logger()) + } + + return d.emailClient +} + +func (d *dependency) MimeDetector(ctx context.Context) mime.MimeDetector { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.mimeDetector != nil && !reload { + return d.mimeDetector + } + + d.mimeDetector = mime.NewMimeDetector(ctx, d.SettingProvider(), d.Logger()) + return d.mimeDetector +} + +func (d *dependency) MediaMetaExtractor(ctx context.Context) mediameta.Extractor { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.mediaMeta != nil && !reload { + return d.mediaMeta + } + + d.mediaMeta = mediameta.NewExtractorManager(ctx, d.SettingProvider(), d.Logger()) + return d.mediaMeta +} + +func (d *dependency) ThumbQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.thumbQueue != nil && !reload { + return d.thumbQueue + } + + if d.thumbQueue != nil { + d.thumbQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), setting.QueueTypeThumb) + var ( + t inventory.TaskClient + ) + if d.ConfigProvider().System().Mode == conf.MasterMode { + t = d.TaskClient() + } + + d.thumbQueue = queue.New(d.Logger(), t, nil, d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + 
queue.WithName("ThumbQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + ) + return d.thumbQueue +} + +func (d *dependency) MediaMetaQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.mediaMetaQueue != nil && !reload { + return d.mediaMetaQueue + } + + if d.mediaMetaQueue != nil { + d.mediaMetaQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), setting.QueueTypeMediaMeta) + + d.mediaMetaQueue = queue.New(d.Logger(), d.TaskClient(), nil, d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + queue.WithName("MediaMetadataQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + queue.WithResumeTaskType(queue.MediaMetaTaskType), + ) + return d.mediaMetaQueue +} + +func (d *dependency) IoIntenseQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.ioIntenseQueue != nil && !reload { + return d.ioIntenseQueue + } + + if d.ioIntenseQueue != nil { + d.ioIntenseQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), setting.QueueTypeIOIntense) + + d.ioIntenseQueue = queue.New(d.Logger(), d.TaskClient(), d.TaskRegistry(), d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + queue.WithName("IoIntenseQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + queue.WithResumeTaskType(queue.CreateArchiveTaskType, queue.ExtractArchiveTaskType, queue.RelocateTaskType), + queue.WithTaskPullInterval(10*time.Second), + ) + return d.ioIntenseQueue +} + +func (d *dependency) RemoteDownloadQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.remoteDownloadQueue != nil && !reload { + return d.remoteDownloadQueue + } + + if d.remoteDownloadQueue != nil { + d.remoteDownloadQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), setting.QueueTypeRemoteDownload) + + d.remoteDownloadQueue = queue.New(d.Logger(), d.TaskClient(), d.TaskRegistry(), d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + queue.WithName("RemoteDownloadQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + queue.WithResumeTaskType(queue.RemoteDownloadTaskType), + queue.WithTaskPullInterval(20*time.Second), + ) + return d.remoteDownloadQueue +} + +func (d *dependency) EntityRecycleQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.entityRecycleQueue != nil && !reload { + return d.entityRecycleQueue + } + + if d.entityRecycleQueue != nil { + d.entityRecycleQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), 
setting.QueueTypeEntityRecycle) + + d.entityRecycleQueue = queue.New(d.Logger(), d.TaskClient(), nil, d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + queue.WithName("EntityRecycleQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + queue.WithResumeTaskType(queue.EntityRecycleRoutineTaskType, queue.ExplicitEntityRecycleTaskType, queue.UploadSentinelCheckTaskType), + queue.WithTaskPullInterval(10*time.Second), + ) + return d.entityRecycleQueue +} + +func (d *dependency) SlaveQueue(ctx context.Context) queue.Queue { + d.mu.Lock() + defer d.mu.Unlock() + + _, reload := ctx.Value(ReloadCtx{}).(bool) + if d.slaveQueue != nil && !reload { + return d.slaveQueue + } + + if d.slaveQueue != nil { + d.slaveQueue.Shutdown() + } + + settings := d.SettingProvider() + queueSetting := settings.Queue(context.Background(), setting.QueueTypeSlave) + + d.slaveQueue = queue.New(d.Logger(), nil, nil, d, + queue.WithBackoffFactor(queueSetting.BackoffFactor), + queue.WithMaxRetry(queueSetting.MaxRetry), + queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration), + queue.WithRetryDelay(queueSetting.RetryDelay), + queue.WithWorkerCount(queueSetting.WorkerNum), + queue.WithName("SlaveQueue"), + queue.WithMaxTaskExecution(queueSetting.MaxExecution), + ) + return d.slaveQueue +} + +func (d *dependency) GeneralAuth() auth.Auth { + if d.generalAuth != nil { + return d.generalAuth + } + + var secretKey string + if d.ConfigProvider().System().Mode == conf.MasterMode { + secretKey = d.SettingProvider().SecretKey(context.Background()) + } else { + secretKey = d.ConfigProvider().Slave().Secret + if secretKey == "" { + d.panicError(errors.New("SlaveSecret is not set, please specify it in config file")) + } + } + + d.generalAuth = auth.HMACAuth{ + SecretKey: []byte(secretKey), + } + + return d.generalAuth +} + +func (d *dependency) FileClient() inventory.FileClient { + if d.fileClient != nil { + return d.fileClient + } + + return inventory.NewFileClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder()) +} + +func (d *dependency) ShareClient() inventory.ShareClient { + if d.shareClient != nil { + return d.shareClient + } + + return inventory.NewShareClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder()) +} + +func (d *dependency) TaskClient() inventory.TaskClient { + if d.taskClient != nil { + return d.taskClient + } + + return inventory.NewTaskClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder()) +} + +func (d *dependency) DavAccountClient() inventory.DavAccountClient { + if d.davAccountClient != nil { + return d.davAccountClient + } + + return inventory.NewDavAccountClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder()) +} + +func (d *dependency) DirectLinkClient() inventory.DirectLinkClient { + if d.directLinkClient != nil { + return d.directLinkClient + } + + return inventory.NewDirectLinkClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder()) +} + +func (d *dependency) HashIDEncoder() hashid.Encoder { + if d.hashidEncoder != nil { + return d.hashidEncoder + } + + encoder, err := hashid.New(d.SettingProvider().HashIDSalt(context.Background())) + if err != nil { + d.panicError(err) + } + + d.hashidEncoder = encoder + return d.hashidEncoder +} + +func (d *dependency) CredManager() 
credmanager.CredManager { + if d.credManager != nil { + return d.credManager + } + + if d.ConfigProvider().System().Mode == conf.MasterMode { + d.credManager = credmanager.New(d.KV()) + } else { + d.credManager = credmanager.NewSlaveManager(d.KV(), d.ConfigProvider()) + } + return d.credManager +} + +func (d *dependency) TokenAuth() auth.TokenAuth { + if d.tokenAuth != nil { + return d.tokenAuth + } + + d.tokenAuth = auth.NewTokenAuth(d.HashIDEncoder(), d.SettingProvider(), + []byte(d.SettingProvider().SecretKey(context.Background())), d.UserClient(), d.Logger()) + return d.tokenAuth +} + +func (d *dependency) LockSystem() lock.LockSystem { + if d.lockSystem != nil { + return d.lockSystem + } + + d.lockSystem = lock.NewMemLS(d.HashIDEncoder(), d.Logger()) + return d.lockSystem +} + +func (d *dependency) StoragePolicyClient() inventory.StoragePolicyClient { + if d.storagePolicyClient != nil { + return d.storagePolicyClient + } + + return inventory.NewStoragePolicyClient(d.DBClient(), d.KV()) +} + +func (d *dependency) ThumbPipeline() thumb.Generator { + if d.thumbPipeline != nil { + return d.thumbPipeline + } + + d.thumbPipeline = thumb.NewPipeline(d.SettingProvider(), d.Logger()) + return d.thumbPipeline +} + +func (d *dependency) TaskRegistry() queue.TaskRegistry { + if d.taskRegistry != nil { + return d.taskRegistry + } + + d.taskRegistry = queue.NewTaskRegistry() + return d.taskRegistry +} + +func (d *dependency) Shutdown(ctx context.Context) error { + d.mu.Lock() + + if d.emailClient != nil { + d.emailClient.Close() + } + + wg := sync.WaitGroup{} + + if d.mediaMetaQueue != nil { + wg.Add(1) + go func() { + d.mediaMetaQueue.Shutdown() + defer wg.Done() + }() + } + + if d.thumbQueue != nil { + wg.Add(1) + go func() { + d.thumbQueue.Shutdown() + defer wg.Done() + }() + } + + if d.ioIntenseQueue != nil { + wg.Add(1) + go func() { + d.ioIntenseQueue.Shutdown() + defer wg.Done() + }() + } + + if d.entityRecycleQueue != nil { + wg.Add(1) + go func() { + d.entityRecycleQueue.Shutdown() + defer wg.Done() + }() + } + + if d.slaveQueue != nil { + wg.Add(1) + go func() { + d.slaveQueue.Shutdown() + defer wg.Done() + }() + } + + if d.remoteDownloadQueue != nil { + wg.Add(1) + go func() { + d.remoteDownloadQueue.Shutdown() + defer wg.Done() + }() + } + + d.mu.Unlock() + wg.Wait() + + return nil +} + +func (d *dependency) panicError(err error) { + if d.logger != nil { + d.logger.Panic("Fatal error in dependency initialization: %s", err) + } + + panic(err) +} + +func (d *dependency) ForkWithLogger(ctx context.Context, l logging.Logger) context.Context { + dep := &dependencyCorrelated{ + l: l, + dependency: d, + } + return context.WithValue(ctx, DepCtx{}, dep) +} + +type dependencyCorrelated struct { + l logging.Logger + *dependency +} + +func (d *dependencyCorrelated) Logger() logging.Logger { + return d.l +} diff --git a/application/dependency/options.go b/application/dependency/options.go new file mode 100644 index 00000000..9c92319e --- /dev/null +++ b/application/dependency/options.go @@ -0,0 +1,165 @@ +package dependency + +import ( + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/email" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gin-contrib/static" + "io/fs" 
+)
+
+// Option applies an optional setting to the dependency container.
+type Option interface {
+    apply(*dependency)
+}
+
+type optionFunc func(*dependency)
+
+func (f optionFunc) apply(o *dependency) {
+    f(o)
+}
+
+// WithConfigPath Set the path of the config file.
+func WithConfigPath(p string) Option {
+    return optionFunc(func(o *dependency) {
+        o.configPath = p
+    })
+}
+
+// WithLogger Set the default logging.
+func WithLogger(l logging.Logger) Option {
+    return optionFunc(func(o *dependency) {
+        o.logger = l
+    })
+}
+
+// WithConfigProvider Set the default config provider.
+func WithConfigProvider(c conf.ConfigProvider) Option {
+    return optionFunc(func(o *dependency) {
+        o.configProvider = c
+    })
+}
+
+// WithStatics Set the default statics FS.
+func WithStatics(c fs.FS) Option {
+    return optionFunc(func(o *dependency) {
+        o.statics = c
+    })
+}
+
+// WithServerStaticFS Set the default statics FS for server.
+func WithServerStaticFS(c static.ServeFileSystem) Option {
+    return optionFunc(func(o *dependency) {
+        o.serverStaticFS = c
+    })
+}
+
+// WithProFlag Set if current instance is a pro version.
+func WithProFlag(c bool) Option {
+    return optionFunc(func(o *dependency) {
+        o.isPro = c
+    })
+}
+
+func WithLicenseKey(c string) Option {
+    return optionFunc(func(o *dependency) {
+        o.licenseKey = c
+    })
+}
+
+// WithRawEntClient Set the default raw ent client.
+func WithRawEntClient(c *ent.Client) Option {
+    return optionFunc(func(o *dependency) {
+        o.rawEntClient = c
+    })
+}
+
+// WithDbClient Set the default ent client.
+func WithDbClient(c *ent.Client) Option {
+    return optionFunc(func(o *dependency) {
+        o.dbClient = c
+    })
+}
+
+// WithRequiredDbVersion Set the required db version.
+func WithRequiredDbVersion(c string) Option {
+    return optionFunc(func(o *dependency) {
+        o.requiredDbVersion = c
+    })
+}
+
+// WithKV Set the default KV store driver
+func WithKV(c cache.Driver) Option {
+    return optionFunc(func(o *dependency) {
+        o.kv = c
+    })
+}
+
+// WithSettingClient Set the default setting client
+func WithSettingClient(s inventory.SettingClient) Option {
+    return optionFunc(func(o *dependency) {
+        o.settingClient = s
+    })
+}
+
+// WithSettingProvider Set the default setting provider
+func WithSettingProvider(s setting.Provider) Option {
+    return optionFunc(func(o *dependency) {
+        o.settingProvider = s
+    })
+}
+
+// WithUserClient Set the default user client
+func WithUserClient(s inventory.UserClient) Option {
+    return optionFunc(func(o *dependency) {
+        o.userClient = s
+    })
+}
+
+// WithEmailClient Set the default email client
+func WithEmailClient(s email.Driver) Option {
+    return optionFunc(func(o *dependency) {
+        o.emailClient = s
+    })
+}
+
+// WithGeneralAuth Set the default general auth
+func WithGeneralAuth(s auth.Auth) Option {
+    return optionFunc(func(o *dependency) {
+        o.generalAuth = s
+    })
+}
+
+// WithHashIDEncoder Set the default hash id encoder
+func WithHashIDEncoder(s hashid.Encoder) Option {
+    return optionFunc(func(o *dependency) {
+        o.hashidEncoder = s
+    })
+}
+
+// WithTokenAuth Set the default token auth
+func WithTokenAuth(s auth.TokenAuth) Option {
+    return optionFunc(func(o *dependency) {
+        o.tokenAuth = s
+    })
+}
+
+// WithFileClient Set the default file client
+func WithFileClient(s inventory.FileClient) Option {
+    return optionFunc(func(o *dependency) {
+        o.fileClient = s
+    })
+}
+
+// WithShareClient Set the default share client
+func WithShareClient(s inventory.ShareClient) Option {
+    return optionFunc(func(o *dependency) {
+        o.shareClient = s
+    })
+}
diff --git a/application/migrator/avatars.go b/application/migrator/avatars.go
new file mode 100644
index 00000000..d9d392ef
--- /dev/null
+++ b/application/migrator/avatars.go
@@ -0,0 +1,47 @@
+package migrator
+
+import (
+    "fmt"
+    "io"
+    "os"
+    "path/filepath"
+
+    "github.com/cloudreve/Cloudreve/v4/pkg/util"
+)
+
+func migrateAvatars(m *Migrator) error {
+    m.l.Info("Migrating avatar files...")
+    avatarRoot := util.RelativePath(m.state.V3AvatarPath)
+
+    for uid := range m.state.UserIDs {
+        avatarPath := filepath.Join(avatarRoot, fmt.Sprintf("avatar_%d_2.png", uid))
+
+        // check if file exists
+        if util.Exists(avatarPath) {
+            m.l.Info("Migrating avatar for user %d", uid)
+            // Copy to v4 avatar path
+            v4Path := filepath.Join(util.DataPath("avatar"), fmt.Sprintf("avatar_%d.png", uid))
+
+            // copy
+            origin, err := os.Open(avatarPath)
+            if err != nil {
+                return fmt.Errorf("failed to open avatar file: %w", err)
+            }
+            defer origin.Close()
+
+            dest, err := util.CreatNestedFile(v4Path)
+            if err != nil {
+                return fmt.Errorf("failed to create avatar file: %w", err)
+            }
+            defer dest.Close()
+
+            _, err = io.Copy(dest, origin)
+
+            if err != nil {
+                m.l.Warning("Failed to copy avatar file: %s, skipping...", err)
+            }
+        }
+    }
+
+    return nil
+}
diff --git a/application/migrator/conf/conf.go b/application/migrator/conf/conf.go
new file mode 100644
index 00000000..e34fc303
--- /dev/null
+++ b/application/migrator/conf/conf.go
@@ -0,0 +1,124 @@
+package conf
+
+import (
+    "github.com/cloudreve/Cloudreve/v4/pkg/logging"
+    "github.com/go-ini/ini"
+    "github.com/go-playground/validator/v10"
+)
+
+// database Database configuration
+type database struct {
+    Type        string
+    User        string
+    Password    string
+    Host        string
+    Name        string
+    TablePrefix string
+    DBFile      string
+    Port        int
+    Charset     string
+    UnixSocket  bool
+}
+
+// system General system configuration
+type system struct {
+    Mode          string `validate:"eq=master|eq=slave"`
+    Listen        string `validate:"required"`
+    Debug         bool
+    SessionSecret string
+    HashIDSalt    string
+    GracePeriod   int    `validate:"gte=0"`
+    ProxyHeader   string `validate:"required_with=Listen"`
+}
+
+type ssl struct {
+    CertPath string `validate:"omitempty,required"`
+    KeyPath  string `validate:"omitempty,required"`
+    Listen   string `validate:"required"`
+}
+
+type unix struct {
+    Listen string
+    Perm   uint32
+}
+
+// slave Configuration for running as a slave storage node
+type slave struct {
+    Secret          string `validate:"omitempty,gte=64"`
+    CallbackTimeout int    `validate:"omitempty,gte=1"`
+    SignatureTTL    int    `validate:"omitempty,gte=1"`
+}
+
+// redis Redis configuration
+type redis struct {
+    Network  string
+    Server   string
+    User     string
+    Password string
+    DB       string
+}
+
+// Cross-origin (CORS) configuration
+type cors struct {
+    AllowOrigins     []string
+    AllowMethods     []string
+    AllowHeaders     []string
+    AllowCredentials bool
+    ExposeHeaders    []string
+    SameSite         string
+    Secure           bool
+}
+
+var cfg *ini.File
+
+// Init initializes the configuration file
+func Init(l logging.Logger, path string) error {
+    var err error
+
+    cfg, err = ini.Load(path)
+    if err != nil {
+        l.Error("Failed to parse config file %q: %s", path, err)
+        return err
+    }
+
+    sections := map[string]interface{}{
+        "Database":   DatabaseConfig,
+        "System":     SystemConfig,
+        "SSL":        SSLConfig,
+        "UnixSocket": UnixConfig,
+        "Redis":      RedisConfig,
+        "CORS":       CORSConfig,
+        "Slave":      SlaveConfig,
+    }
+    for sectionName, sectionStruct := range sections {
+        err = mapSection(sectionName, sectionStruct)
+        if err != nil {
+            l.Error("Failed to parse config section %q: %s", sectionName, err)
+            return err
+        }
+    }
+
+    // Map database setting overrides
+    for _, key := range cfg.Section("OptionOverwrite").Keys() {
+        OptionOverwrite[key.Name()] =
key.Value() + } + + return nil +} + +// mapSection 将配置文件的 Section 映射到结构体上 +func mapSection(section string, confStruct interface{}) error { + err := cfg.Section(section).MapTo(confStruct) + if err != nil { + return err + } + + // 验证合法性 + validate := validator.New() + err = validate.Struct(confStruct) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/conf/defaults.go b/application/migrator/conf/defaults.go similarity index 100% rename from pkg/conf/defaults.go rename to application/migrator/conf/defaults.go diff --git a/application/migrator/directlink.go b/application/migrator/directlink.go new file mode 100644 index 00000000..cb18d3d4 --- /dev/null +++ b/application/migrator/directlink.go @@ -0,0 +1,82 @@ +package migrator + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +func (m *Migrator) migrateDirectLink() error { + m.l.Info("Migrating direct links...") + batchSize := 1000 + offset := m.state.DirectLinkOffset + ctx := context.Background() + + if m.state.DirectLinkOffset > 0 { + m.l.Info("Resuming direct link migration from offset %d", offset) + } + + for { + m.l.Info("Migrating direct links with offset %d", offset) + var directLinks []model.SourceLink + if err := model.DB.Limit(batchSize).Offset(offset).Find(&directLinks).Error; err != nil { + return fmt.Errorf("failed to list v3 direct links: %w", err) + } + + if len(directLinks) == 0 { + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting direct link ID sequence for postgres...") + m.v4client.DirectLink.ExecContext(ctx, "SELECT SETVAL('direct_links_id_seq', (SELECT MAX(id) FROM direct_links))") + } + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, dl := range directLinks { + sourceId := int(dl.FileID) + m.state.LastFolderID + // check if file exists + _, err = tx.File.Query().Where(file.ID(sourceId)).First(ctx) + if err != nil { + m.l.Warning("File %d not found, skipping direct link %d", sourceId, dl.ID) + continue + } + + stm := tx.DirectLink.Create(). + SetCreatedAt(formatTime(dl.CreatedAt)). + SetUpdatedAt(formatTime(dl.UpdatedAt)). + SetRawID(int(dl.ID)). + SetFileID(sourceId). + SetName(dl.Name). + SetDownloads(dl.Downloads). 
+ SetSpeed(0) + + if _, err := stm.Save(ctx); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to create direct link %d: %w", dl.ID, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + offset += batchSize + m.state.DirectLinkOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after direct link batch: %s", err) + } else { + m.l.Info("Saved migration state after processing this batch") + } + } + + return nil + +} diff --git a/application/migrator/file.go b/application/migrator/file.go new file mode 100644 index 00000000..65390881 --- /dev/null +++ b/application/migrator/file.go @@ -0,0 +1,189 @@ +package migrator + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +func (m *Migrator) migrateFile() error { + m.l.Info("Migrating files...") + batchSize := 1000 + offset := m.state.FileOffset + ctx := context.Background() + + if m.state.FileConflictRename == nil { + m.state.FileConflictRename = make(map[uint]string) + } + + if m.state.EntitySources == nil { + m.state.EntitySources = make(map[string]int) + } + + if offset > 0 { + m.l.Info("Resuming file migration from offset %d", offset) + } + +out: + for { + m.l.Info("Migrating files with offset %d", offset) + var files []model.File + if err := model.DB.Limit(batchSize).Offset(offset).Find(&files).Error; err != nil { + return fmt.Errorf("failed to list v3 files: %w", err) + } + + if len(files) == 0 { + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting file ID sequence for postgres...") + m.v4client.File.ExecContext(ctx, "SELECT SETVAL('files_id_seq', (SELECT MAX(id) FROM files))") + } + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, f := range files { + if _, ok := m.state.FolderIDs[int(f.FolderID)]; !ok { + m.l.Warning("Folder ID %d for file %d not found, skipping", f.FolderID, f.ID) + continue + } + + if _, ok := m.state.UserIDs[int(f.UserID)]; !ok { + m.l.Warning("User ID %d for file %d not found, skipping", f.UserID, f.ID) + continue + } + + if _, ok := m.state.PolicyIDs[int(f.PolicyID)]; !ok { + m.l.Warning("Policy ID %d for file %d not found, skipping", f.PolicyID, f.ID) + continue + } + + metadata := make(map[string]string) + if f.Metadata != "" { + json.Unmarshal([]byte(f.Metadata), &metadata) + } + + var ( + thumbnail *ent.Entity + entity *ent.Entity + err error + ) + + if metadata[model.ThumbStatusMetadataKey] == model.ThumbStatusExist { + size := int64(0) + if m.state.LocalPolicyIDs[int(f.PolicyID)] { + thumbFile, err := os.Stat(f.SourceName + m.state.ThumbSuffix) + if err == nil { + size = thumbFile.Size() + } + m.l.Warning("Thumbnail file %s for file %d not found, use 0 size", f.SourceName+m.state.ThumbSuffix, f.ID) + } + // Insert thumbnail entity + thumbnail, err = m.insertEntity(tx, f.SourceName+m.state.ThumbSuffix, int(types.EntityTypeThumbnail), int(f.PolicyID), int(f.UserID), size) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to insert thumbnail entity: %w", err) + } + } + + // Insert file version entity + entity, err = m.insertEntity(tx, f.SourceName, int(types.EntityTypeVersion), int(f.PolicyID), int(f.UserID), 
int64(f.Size)) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to insert file version entity: %w", err) + } + + fname := f.Name + if _, ok := m.state.FileConflictRename[f.ID]; ok { + fname = m.state.FileConflictRename[f.ID] + } + + stm := tx.File.Create(). + SetCreatedAt(formatTime(f.CreatedAt)). + SetUpdatedAt(formatTime(f.UpdatedAt)). + SetName(fname). + SetRawID(int(f.ID) + m.state.LastFolderID). + SetOwnerID(int(f.UserID)). + SetSize(int64(f.Size)). + SetPrimaryEntity(entity.ID). + SetFileChildren(int(f.FolderID)). + SetType(int(types.FileTypeFile)). + SetStoragePoliciesID(int(f.PolicyID)). + AddEntities(entity) + + if thumbnail != nil { + stm.AddEntities(thumbnail) + } + + if _, err := stm.Save(ctx); err != nil { + _ = tx.Rollback() + if ent.IsConstraintError(err) { + if _, ok := m.state.FileConflictRename[f.ID]; ok { + return fmt.Errorf("file %d already exists, but new name is already in conflict rename map, please resolve this manually", f.ID) + } + + m.l.Warning("File %d already exists, will retry with new name in next batch", f.ID) + m.state.FileConflictRename[f.ID] = fmt.Sprintf("%d_%s", f.ID, f.Name) + continue out + } + return fmt.Errorf("failed to create file %d: %w", f.ID, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + offset += batchSize + m.state.FileOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after file batch: %s", err) + } else { + m.l.Info("Saved migration state after processing this batch") + } + } + + return nil +} + +func (m *Migrator) insertEntity(tx *ent.Tx, source string, entityType, policyID, createdBy int, size int64) (*ent.Entity, error) { + + // find existing one + entityKey := strconv.Itoa(policyID) + "+" + source + if existingId, ok := m.state.EntitySources[entityKey]; ok { + existing, err := tx.Entity.UpdateOneID(existingId). + AddReferenceCount(1). + Save(context.Background()) + if err == nil { + return existing, nil + } + m.l.Warning("Failed to update existing entity %d: %s, fallback to create new one.", existingId, err) + } + + // create new one + e, err := tx.Entity.Create(). + SetSource(source). + SetType(entityType). + SetSize(size). + SetStoragePolicyEntities(policyID). + SetCreatedBy(createdBy). + SetReferenceCount(1). 
+ Save(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to create new entity: %w", err) + } + + m.state.EntitySources[entityKey] = e.ID + return e, nil +} diff --git a/application/migrator/folders.go b/application/migrator/folders.go new file mode 100644 index 00000000..beea2d78 --- /dev/null +++ b/application/migrator/folders.go @@ -0,0 +1,147 @@ +package migrator + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +func (m *Migrator) migrateFolders() error { + m.l.Info("Migrating folders...") + batchSize := 1000 + // Start from the saved offset if available + offset := m.state.FolderOffset + ctx := context.Background() + foldersCount := 0 + + if m.state.FolderIDs == nil { + m.state.FolderIDs = make(map[int]bool) + } + + if offset > 0 { + m.l.Info("Resuming folder migration from offset %d", offset) + } + + for { + m.l.Info("Migrating folders with offset %d", offset) + var folders []model.Folder + if err := model.DB.Limit(batchSize).Offset(offset).Find(&folders).Error; err != nil { + return fmt.Errorf("failed to list v3 folders: %w", err) + } + + if len(folders) == 0 { + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + batchFoldersCount := 0 + for _, f := range folders { + if _, ok := m.state.UserIDs[int(f.OwnerID)]; !ok { + m.l.Warning("Owner ID %d not found, skipping folder %d", f.OwnerID, f.ID) + continue + } + + isRoot := f.ParentID == nil + if isRoot { + f.Name = "" + } else if *f.ParentID == 0 { + m.l.Warning("Parent ID %d not found, skipping folder %d", *f.ParentID, f.ID) + continue + } + + stm := tx.File.Create(). + SetRawID(int(f.ID)). + SetType(int(types.FileTypeFolder)). + SetCreatedAt(formatTime(f.CreatedAt)). + SetUpdatedAt(formatTime(f.UpdatedAt)). + SetName(f.Name). 
+ SetOwnerID(int(f.OwnerID)) + + if _, err := stm.Save(ctx); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to create folder %d: %w", f.ID, err) + } + + m.state.FolderIDs[int(f.ID)] = true + m.state.LastFolderID = int(f.ID) + + foldersCount++ + batchFoldersCount++ + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + // Update the offset in state and save after each batch + offset += batchSize + m.state.FolderOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after folder batch: %s", err) + } else { + m.l.Info("Saved migration state after processing %d folders in this batch", batchFoldersCount) + } + } + + m.l.Info("Successfully migrated %d folders", foldersCount) + return nil +} + +func (m *Migrator) migrateFolderParent() error { + m.l.Info("Migrating folder parent...") + batchSize := 1000 + offset := m.state.FolderParentOffset + ctx := context.Background() + + for { + m.l.Info("Migrating folder parent with offset %d", offset) + var folderParents []model.Folder + if err := model.DB.Limit(batchSize).Offset(offset).Find(&folderParents).Error; err != nil { + return fmt.Errorf("failed to list v3 folder parents: %w", err) + } + + if len(folderParents) == 0 { + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, f := range folderParents { + if f.ParentID != nil { + if _, ok := m.state.FolderIDs[int(*f.ParentID)]; !ok { + m.l.Warning("Folder ID %d not found, skipping folder parent %d", *f.ParentID, f.ID) + continue + } + + if _, err := tx.File.UpdateOneID(int(f.ID)).SetParentID(int(*f.ParentID)).Save(ctx); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to update folder parent %d: %w", f.ID, err) + } + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + // Update the offset in state and save after each batch + offset += batchSize + m.state.FolderParentOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after folder parent batch: %s", err) + } + } + + return nil +} diff --git a/application/migrator/group.go b/application/migrator/group.go new file mode 100644 index 00000000..8f8ae79d --- /dev/null +++ b/application/migrator/group.go @@ -0,0 +1,92 @@ +package migrator + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/samber/lo" +) + +func (m *Migrator) migrateGroup() error { + m.l.Info("Migrating groups...") + + var groups []model.Group + if err := model.DB.Find(&groups).Error; err != nil { + return fmt.Errorf("failed to list v3 groups: %w", err) + } + + for _, group := range groups { + cap := &boolset.BooleanSet{} + var ( + opts model.GroupOption + policies []int + ) + if err := json.Unmarshal([]byte(group.Options), &opts); err != nil { + return fmt.Errorf("failed to unmarshal options for group %q: %w", group.Name, err) + } + + if err := json.Unmarshal([]byte(group.Policies), &policies); err != nil { + return fmt.Errorf("failed to unmarshal policies for group %q: %w", group.Name, err) + } + + policies = lo.Filter(policies, func(id int, _ int) bool { + _, exist := m.state.PolicyIDs[id] + return exist + }) + + newOpts := &types.GroupSetting{ +
CompressSize: int64(opts.CompressSize), + DecompressSize: int64(opts.DecompressSize), + RemoteDownloadOptions: opts.Aria2Options, + SourceBatchSize: opts.SourceBatchSize, + RedirectedSource: opts.RedirectedSource, + Aria2BatchSize: opts.Aria2BatchSize, + MaxWalkedFiles: 100000, + TrashRetention: 7 * 24 * 3600, + } + + boolset.Sets(map[types.GroupPermission]bool{ + types.GroupPermissionIsAdmin: group.ID == 1, + types.GroupPermissionIsAnonymous: group.ID == 3, + types.GroupPermissionShareDownload: opts.ShareDownload, + types.GroupPermissionWebDAV: group.WebDAVEnabled, + types.GroupPermissionArchiveDownload: opts.ArchiveDownload, + types.GroupPermissionArchiveTask: opts.ArchiveTask, + types.GroupPermissionWebDAVProxy: opts.WebDAVProxy, + types.GroupPermissionRemoteDownload: opts.Aria2, + types.GroupPermissionAdvanceDelete: opts.AdvanceDelete, + types.GroupPermissionShare: group.ShareEnabled, + types.GroupPermissionRedirectedSource: opts.RedirectedSource, + }, cap) + + stm := m.v4client.Group.Create(). + SetRawID(int(group.ID)). + SetCreatedAt(formatTime(group.CreatedAt)). + SetUpdatedAt(formatTime(group.UpdatedAt)). + SetName(group.Name). + SetMaxStorage(int64(group.MaxStorage)). + SetSpeedLimit(group.SpeedLimit). + SetPermissions(cap). + SetSettings(newOpts) + + if len(policies) > 0 { + stm.SetStoragePoliciesID(policies[0]) + } + + if _, err := stm.Save(context.Background()); err != nil { + return fmt.Errorf("failed to create group %q: %w", group.Name, err) + } + } + + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting group ID sequence for postgres...") + m.v4client.Group.ExecContext(context.Background(), "SELECT SETVAL('groups_id_seq', (SELECT MAX(id) FROM groups))") + } + + return nil +} diff --git a/application/migrator/migrator.go b/application/migrator/migrator.go new file mode 100644 index 00000000..5bb53520 --- /dev/null +++ b/application/migrator/migrator.go @@ -0,0 +1,314 @@ +package migrator + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/application/migrator/conf" + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/util" +) + +// State stores the migration progress +type State struct { + PolicyIDs map[int]bool `json:"policy_ids,omitempty"` + LocalPolicyIDs map[int]bool `json:"local_policy_ids,omitempty"` + UserIDs map[int]bool `json:"user_ids,omitempty"` + FolderIDs map[int]bool `json:"folder_ids,omitempty"` + EntitySources map[string]int `json:"entity_sources,omitempty"` + LastFolderID int `json:"last_folder_id,omitempty"` + Step int `json:"step,omitempty"` + UserOffset int `json:"user_offset,omitempty"` + FolderOffset int `json:"folder_offset,omitempty"` + FileOffset int `json:"file_offset,omitempty"` + ShareOffset int `json:"share_offset,omitempty"` + GiftCodeOffset int `json:"gift_code_offset,omitempty"` + DirectLinkOffset int `json:"direct_link_offset,omitempty"` + WebdavOffset int `json:"webdav_offset,omitempty"` + StoragePackOffset int `json:"storage_pack_offset,omitempty"` + FileConflictRename map[uint]string `json:"file_conflict_rename,omitempty"` + FolderParentOffset int `json:"folder_parent_offset,omitempty"` + ThumbSuffix string `json:"thumb_suffix,omitempty"` + V3AvatarPath string 
`json:"v3_avatar_path,omitempty"` +} + +// Step identifiers for migration phases +const ( + StepInitial = 0 + StepSchema = 1 + StepSettings = 2 + StepNode = 3 + StepPolicy = 4 + StepGroup = 5 + StepUser = 6 + StepFolders = 7 + StepFolderParent = 8 + StepFile = 9 + StepShare = 10 + StepDirectLink = 11 + Step_CommunityPlaceholder1 = 12 + Step_CommunityPlaceholder2 = 13 + StepAvatar = 14 + StepWebdav = 15 + StepCompleted = 16 + StateFileName = "migration_state.json" +) + +type Migrator struct { + dep dependency.Dep + l logging.Logger + v4client *ent.Client + state *State + statePath string +} + +func NewMigrator(dep dependency.Dep, v3ConfPath string) (*Migrator, error) { + m := &Migrator{ + dep: dep, + l: dep.Logger(), + state: &State{ + PolicyIDs: make(map[int]bool), + UserIDs: make(map[int]bool), + Step: StepInitial, + UserOffset: 0, + FolderOffset: 0, + }, + } + + // Determine state file path + configDir := filepath.Dir(v3ConfPath) + m.statePath = filepath.Join(configDir, StateFileName) + + // Try to load existing state + if util.Exists(m.statePath) { + m.l.Info("Found existing migration state file, loading from %s", m.statePath) + if err := m.loadState(); err != nil { + return nil, fmt.Errorf("failed to load migration state: %w", err) + } + + stepName := "unknown" + switch m.state.Step { + case StepInitial: + stepName = "initial" + case StepSchema: + stepName = "schema creation" + case StepSettings: + stepName = "settings migration" + case StepNode: + stepName = "node migration" + case StepPolicy: + stepName = "policy migration" + case StepGroup: + stepName = "group migration" + case StepUser: + stepName = "user migration" + case StepFolders: + stepName = "folders migration" + case StepCompleted: + stepName = "completed" + case StepWebdav: + stepName = "webdav migration" + case StepAvatar: + stepName = "avatar migration" + + } + + m.l.Info("Resumed migration from step %d (%s)", m.state.Step, stepName) + + // Log batch information if applicable + if m.state.Step == StepUser && m.state.UserOffset > 0 { + m.l.Info("Will resume user migration from batch offset %d", m.state.UserOffset) + } + if m.state.Step == StepFolders && m.state.FolderOffset > 0 { + m.l.Info("Will resume folder migration from batch offset %d", m.state.FolderOffset) + } + } + + err := conf.Init(m.dep.Logger(), v3ConfPath) + if err != nil { + return nil, err + } + + err = model.Init() + if err != nil { + return nil, err + } + + v4client, err := inventory.NewRawEntClient(m.l, m.dep.ConfigProvider()) + if err != nil { + return nil, err + } + + m.v4client = v4client + return m, nil +} + +// saveState persists migration state to file +func (m *Migrator) saveState() error { + data, err := json.Marshal(m.state) + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + + return os.WriteFile(m.statePath, data, 0644) +} + +// loadState reads migration state from file +func (m *Migrator) loadState() error { + data, err := os.ReadFile(m.statePath) + if err != nil { + return fmt.Errorf("failed to read state file: %w", err) + } + + return json.Unmarshal(data, m.state) +} + +// updateStep updates current step and persists state +func (m *Migrator) updateStep(step int) error { + m.state.Step = step + return m.saveState() +} + +func (m *Migrator) Migrate() error { + // Continue from the current step + if m.state.Step <= StepSchema { + m.l.Info("Creating basic v4 table schema...") + if err := m.v4client.Schema.Create(context.Background()); err != nil { + return fmt.Errorf("failed creating schema resources: %w", err) + } 
+ if err := m.updateStep(StepSettings); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepSettings { + if err := m.migrateSettings(); err != nil { + return err + } + if err := m.updateStep(StepNode); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepNode { + if err := m.migrateNode(); err != nil { + return err + } + if err := m.updateStep(StepPolicy); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepPolicy { + allPolicyIDs, err := m.migratePolicy() + if err != nil { + return err + } + m.state.PolicyIDs = allPolicyIDs + if err := m.updateStep(StepGroup); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepGroup { + if err := m.migrateGroup(); err != nil { + return err + } + if err := m.updateStep(StepUser); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepUser { + if err := m.migrateUser(); err != nil { + m.saveState() + return err + } + // Reset user offset after completion + m.state.UserOffset = 0 + if err := m.updateStep(StepFolders); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepFolders { + if err := m.migrateFolders(); err != nil { + m.saveState() + return err + } + // Reset folder offset after completion + m.state.FolderOffset = 0 + if err := m.updateStep(StepFolderParent); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepFolderParent { + if err := m.migrateFolderParent(); err != nil { + return err + } + if err := m.updateStep(StepFile); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepFile { + if err := m.migrateFile(); err != nil { + return err + } + if err := m.updateStep(StepShare); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepShare { + if err := m.migrateShare(); err != nil { + return err + } + if err := m.updateStep(StepDirectLink); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepDirectLink { + if err := m.migrateDirectLink(); err != nil { + return err + } + if err := m.updateStep(StepAvatar); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepAvatar { + if err := migrateAvatars(m); err != nil { + return err + } + if err := m.updateStep(StepWebdav); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + + if m.state.Step <= StepWebdav { + if err := m.migrateWebdav(); err != nil { + return err + } + if err := m.updateStep(StepCompleted); err != nil { + return fmt.Errorf("failed to update step: %w", err) + } + } + m.l.Info("Migration completed successfully") + return nil +} + +func formatTime(t time.Time) time.Time { + newTime := time.UnixMilli(t.UnixMilli()) + return newTime +} diff --git a/models/dialects/dialect_sqlite.go b/application/migrator/model/dialects/dialect_sqlite.go similarity index 100% rename from models/dialects/dialect_sqlite.go rename to application/migrator/model/dialects/dialect_sqlite.go diff --git a/application/migrator/model/file.go b/application/migrator/model/file.go new file mode 100644 index 00000000..5f83f10d --- /dev/null +++ b/application/migrator/model/file.go @@ -0,0 +1,39 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// File 文件 +type 
File struct { + // 表字段 + gorm.Model + Name string `gorm:"unique_index:idx_only_one"` + SourceName string `gorm:"type:text"` + UserID uint `gorm:"index:user_id;unique_index:idx_only_one"` + Size uint64 + PicInfo string + FolderID uint `gorm:"index:folder_id;unique_index:idx_only_one"` + PolicyID uint + UploadSessionID *string `gorm:"index:session_id;unique_index:session_only_one"` + Metadata string `gorm:"type:text"` + + // 关联模型 + Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` + + // 数据库忽略字段 + Position string `gorm:"-"` + MetadataSerialized map[string]string `gorm:"-"` +} + +// Thumb related metadata +const ( + ThumbStatusNotExist = "" + ThumbStatusExist = "exist" + ThumbStatusNotAvailable = "not_available" + + ThumbStatusMetadataKey = "thumb_status" + ThumbSidecarMetadataKey = "thumb_sidecar" + + ChecksumMetadataKey = "webdav_checksum" +) diff --git a/application/migrator/model/folder.go b/application/migrator/model/folder.go new file mode 100644 index 00000000..746eecc3 --- /dev/null +++ b/application/migrator/model/folder.go @@ -0,0 +1,18 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Folder 目录 +type Folder struct { + // 表字段 + gorm.Model + Name string `gorm:"unique_index:idx_only_one_name"` + ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` + OwnerID uint `gorm:"index:owner_id"` + + // 数据库忽略字段 + Position string `gorm:"-"` + WebdavDstName string `gorm:"-"` +} diff --git a/models/group.go b/application/migrator/model/group.go similarity index 57% rename from models/group.go rename to application/migrator/model/group.go index 0abf21db..cec9ea71 100644 --- a/models/group.go +++ b/application/migrator/model/group.go @@ -1,7 +1,6 @@ package model import ( - "encoding/json" "github.com/jinzhu/gorm" ) @@ -37,48 +36,3 @@ type GroupOption struct { AdvanceDelete bool `json:"advance_delete,omitempty"` WebDAVProxy bool `json:"webdav_proxy,omitempty"` } - -// GetGroupByID 用ID获取用户组 -func GetGroupByID(ID interface{}) (Group, error) { - var group Group - result := DB.First(&group, ID) - return group, result.Error -} - -// AfterFind 找到用户组后的钩子,处理Policy列表 -func (group *Group) AfterFind() (err error) { - // 解析用户组策略列表 - if group.Policies != "" { - err = json.Unmarshal([]byte(group.Policies), &group.PolicyList) - } - if err != nil { - return err - } - - // 解析用户组设置 - if group.Options != "" { - err = json.Unmarshal([]byte(group.Options), &group.OptionsSerialized) - } - - return err -} - -// BeforeSave Save用户前的钩子 -func (group *Group) BeforeSave() (err error) { - err = group.SerializePolicyList() - return err -} - -// SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段 -// TODO 完善测试 -func (group *Group) SerializePolicyList() (err error) { - policies, err := json.Marshal(&group.PolicyList) - group.Policies = string(policies) - if err != nil { - return err - } - - optionsValue, err := json.Marshal(&group.OptionsSerialized) - group.Options = string(optionsValue) - return err -} diff --git a/application/migrator/model/init.go b/application/migrator/model/init.go new file mode 100644 index 00000000..46f8d362 --- /dev/null +++ b/application/migrator/model/init.go @@ -0,0 +1,91 @@ +package model + +import ( + "fmt" + "time" + + "github.com/jinzhu/gorm" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + _ "github.com/jinzhu/gorm/dialects/mssql" + _ "github.com/jinzhu/gorm/dialects/mysql" + _ "github.com/jinzhu/gorm/dialects/postgres" +) + +// DB 数据库链接单例 +var DB *gorm.DB + +// Init 初始化 MySQL 链接 +func Init() 
error { + var ( + db *gorm.DB + err error + confDBType string = conf.DatabaseConfig.Type + ) + + // 兼容已有配置中的 "sqlite3" 配置项 + if confDBType == "sqlite3" { + confDBType = "sqlite" + } + + switch confDBType { + case "UNSET", "sqlite": + // 未指定数据库或者明确指定为 sqlite 时,使用 SQLite 数据库 + db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile)) + case "postgres": + db, err = gorm.Open(confDBType, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable", + conf.DatabaseConfig.Host, + conf.DatabaseConfig.User, + conf.DatabaseConfig.Password, + conf.DatabaseConfig.Name, + conf.DatabaseConfig.Port)) + case "mysql", "mssql": + var host string + if conf.DatabaseConfig.UnixSocket { + host = fmt.Sprintf("unix(%s)", + conf.DatabaseConfig.Host) + } else { + host = fmt.Sprintf("(%s:%d)", + conf.DatabaseConfig.Host, + conf.DatabaseConfig.Port) + } + + db, err = gorm.Open(confDBType, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local", + conf.DatabaseConfig.User, + conf.DatabaseConfig.Password, + host, + conf.DatabaseConfig.Name, + conf.DatabaseConfig.Charset)) + default: + return fmt.Errorf("unsupported database type %q", confDBType) + } + + //db.SetLogger(util.Log()) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + // 处理表前缀 + gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string { + return conf.DatabaseConfig.TablePrefix + defaultTableName + } + + // Debug模式下,输出所有 SQL 日志 + db.LogMode(true) + + //设置连接池 + db.DB().SetMaxIdleConns(50) + if confDBType == "sqlite" || confDBType == "UNSET" { + db.DB().SetMaxOpenConns(1) + } else { + db.DB().SetMaxOpenConns(100) + } + + //超时 + db.DB().SetConnMaxLifetime(time.Second * 30) + + DB = db + + return nil +} diff --git a/models/node.go b/application/migrator/model/node.go similarity index 54% rename from models/node.go rename to application/migrator/model/node.go index 992a8280..5ce0e1db 100644 --- a/models/node.go +++ b/application/migrator/model/node.go @@ -1,7 +1,6 @@ package model import ( - "encoding/json" "github.com/jinzhu/gorm" ) @@ -50,42 +49,3 @@ const ( SlaveNodeType ModelType = iota MasterNodeType ) - -// GetNodeByID 用ID获取节点 -func GetNodeByID(ID interface{}) (Node, error) { - var node Node - result := DB.First(&node, ID) - return node, result.Error -} - -// GetNodesByStatus 根据给定状态获取节点 -func GetNodesByStatus(status ...NodeStatus) ([]Node, error) { - var nodes []Node - result := DB.Where("status in (?)", status).Find(&nodes) - return nodes, result.Error -} - -// AfterFind 找到节点后的钩子 -func (node *Node) AfterFind() (err error) { - // 解析离线下载设置到 Aria2OptionsSerialized - if node.Aria2Options != "" { - err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized) - } - - return err -} - -// BeforeSave Save策略前的钩子 -func (node *Node) BeforeSave() (err error) { - optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized) - node.Aria2Options = string(optionsValue) - return err -} - -// SetStatus 设置节点启用状态 -func (node *Node) SetStatus(status NodeStatus) error { - node.Status = status - return DB.Model(node).Updates(map[string]interface{}{ - "status": status, - }).Error -} diff --git a/application/migrator/model/policy.go b/application/migrator/model/policy.go new file mode 100644 index 00000000..fb192101 --- /dev/null +++ b/application/migrator/model/policy.go @@ -0,0 +1,62 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Policy 存储策略 +type Policy struct { + // 表字段 + gorm.Model + Name string + Type string + Server string + BucketName 
string + IsPrivate bool + BaseURL string + AccessKey string `gorm:"type:text"` + SecretKey string `gorm:"type:text"` + MaxSize uint64 + AutoRename bool + DirNameRule string + FileNameRule string + IsOriginLinkEnable bool + Options string `gorm:"type:text"` + + // 数据库忽略字段 + OptionsSerialized PolicyOption `gorm:"-"` + MasterID string `gorm:"-"` +} + +// PolicyOption 非公有的存储策略属性 +type PolicyOption struct { + // Upyun访问Token + Token string `json:"token"` + // 允许的文件扩展名 + FileType []string `json:"file_type"` + // MimeType + MimeType string `json:"mimetype"` + // OauthRedirect Oauth 重定向地址 + OauthRedirect string `json:"od_redirect,omitempty"` + // OdProxy Onedrive 反代地址 + OdProxy string `json:"od_proxy,omitempty"` + // OdDriver OneDrive 驱动器定位符 + OdDriver string `json:"od_driver,omitempty"` + // Region 区域代码 + Region string `json:"region,omitempty"` + // ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段 + ServerSideEndpoint string `json:"server_side_endpoint,omitempty"` + // 分片上传的分片大小 + ChunkSize uint64 `json:"chunk_size,omitempty"` + // 分片上传时是否需要预留空间 + PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"` + // 每秒对存储端的 API 请求上限 + TPSLimit float64 `json:"tps_limit,omitempty"` + // 每秒 API 请求爆发上限 + TPSLimitBurst int `json:"tps_limit_burst,omitempty"` + // Set this to `true` to force the request to use path-style addressing, + // i.e., `http://s3.amazonaws.com/BUCKET/KEY ` + S3ForcePathStyle bool `json:"s3_path_style"` + // File extensions that support thumbnail generation using native policy API. + ThumbExts []string `json:"thumb_exts,omitempty"` +} diff --git a/application/migrator/model/setting.go b/application/migrator/model/setting.go new file mode 100644 index 00000000..2dd57673 --- /dev/null +++ b/application/migrator/model/setting.go @@ -0,0 +1,13 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Setting 系统设置模型 +type Setting struct { + gorm.Model + Type string `gorm:"not null"` + Name string `gorm:"unique;not null;index:setting_key"` + Value string `gorm:"size:65535"` +} diff --git a/application/migrator/model/share.go b/application/migrator/model/share.go new file mode 100644 index 00000000..e9df02b9 --- /dev/null +++ b/application/migrator/model/share.go @@ -0,0 +1,27 @@ +package model + +import ( + "time" + + "github.com/jinzhu/gorm" +) + +// Share 分享模型 +type Share struct { + gorm.Model + Password string // 分享密码,空值为非加密分享 + IsDir bool // 原始资源是否为目录 + UserID uint // 创建用户ID + SourceID uint // 原始资源ID + Views int // 浏览数 + Downloads int // 下载数 + RemainDownloads int // 剩余下载配额,负值标识无限制 + Expires *time.Time // 过期时间,空值表示无过期时间 + PreviewEnabled bool // 是否允许直接预览 + SourceName string `gorm:"index:source"` // 用于搜索的字段 + + // 数据库忽略字段 + User User `gorm:"PRELOAD:false,association_autoupdate:false"` + File File `gorm:"PRELOAD:false,association_autoupdate:false"` + Folder Folder `gorm:"PRELOAD:false,association_autoupdate:false"` +} diff --git a/application/migrator/model/source_link.go b/application/migrator/model/source_link.go new file mode 100644 index 00000000..542a1bd3 --- /dev/null +++ b/application/migrator/model/source_link.go @@ -0,0 +1,16 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// SourceLink represent a shared file source link +type SourceLink struct { + gorm.Model + FileID uint // corresponding file ID + Name string // name of the file while creating the source link, for annotation + Downloads int // 下载数 + + // 关联模型 + File File `gorm:"save_associations:false:false"` +} diff --git a/application/migrator/model/tag.go 
b/application/migrator/model/tag.go new file mode 100644 index 00000000..400a8c7f --- /dev/null +++ b/application/migrator/model/tag.go @@ -0,0 +1,23 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Tag 用户自定义标签 +type Tag struct { + gorm.Model + Name string // 标签名 + Icon string // 图标标识 + Color string // 图标颜色 + Type int // 标签类型(文件分类/目录直达) + Expression string `gorm:"type:text"` // 搜索表表达式/直达路径 + UserID uint // 创建者ID +} + +const ( + // FileTagType 文件分类标签 + FileTagType = iota + // DirectoryLinkType 目录快捷方式标签 + DirectoryLinkType +) diff --git a/application/migrator/model/task.go b/application/migrator/model/task.go new file mode 100644 index 00000000..9b043f9c --- /dev/null +++ b/application/migrator/model/task.go @@ -0,0 +1,16 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Task 任务模型 +type Task struct { + gorm.Model + Status int // 任务状态 + Type int // 任务类型 + UserID uint // 发起者UID,0表示为系统发起 + Progress int // 进度 + Error string `gorm:"type:text"` // 错误信息 + Props string `gorm:"type:text"` // 任务属性 +} diff --git a/application/migrator/model/user.go b/application/migrator/model/user.go new file mode 100644 index 00000000..165b2cdd --- /dev/null +++ b/application/migrator/model/user.go @@ -0,0 +1,45 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +const ( + // Active 账户正常状态 + Active = iota + // NotActivicated 未激活 + NotActivicated + // Baned 被封禁 + Baned + // OveruseBaned 超额使用被封禁 + OveruseBaned +) + +// User 用户模型 +type User struct { + // 表字段 + gorm.Model + Email string `gorm:"type:varchar(100);unique_index"` + Nick string `gorm:"size:50"` + Password string `json:"-"` + Status int + GroupID uint + Storage uint64 + TwoFactor string + Avatar string + Options string `json:"-" gorm:"size:4294967295"` + Authn string `gorm:"size:4294967295"` + + // 关联模型 + Group Group `gorm:"save_associations:false:false"` + Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` + + // 数据库忽略字段 + OptionsSerialized UserOption `gorm:"-"` +} + +// UserOption 用户个性化配置字段 +type UserOption struct { + ProfileOff bool `json:"profile_off,omitempty"` + PreferredTheme string `json:"preferred_theme,omitempty"` +} diff --git a/application/migrator/model/webdav.go b/application/migrator/model/webdav.go new file mode 100644 index 00000000..72634378 --- /dev/null +++ b/application/migrator/model/webdav.go @@ -0,0 +1,16 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Webdav 应用账户 +type Webdav struct { + gorm.Model + Name string // 应用名称 + Password string `gorm:"unique_index:password_only_on"` // 应用密码 + UserID uint `gorm:"unique_index:password_only_on"` // 用户ID + Root string `gorm:"type:text"` // 根目录 + Readonly bool `gorm:"type:bool"` // 是否只读 + UseProxy bool `gorm:"type:bool"` // 是否进行反代 +} diff --git a/application/migrator/node.go b/application/migrator/node.go new file mode 100644 index 00000000..eac2db37 --- /dev/null +++ b/application/migrator/node.go @@ -0,0 +1,89 @@ +package migrator + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +func (m *Migrator) migrateNode() error { + m.l.Info("Migrating nodes...") + + var nodes []model.Node + if err := model.DB.Find(&nodes).Error; err != nil { + return fmt.Errorf("failed to list v3 nodes: %w", err) + } + + for _, n := range nodes { + nodeType := node.TypeSlave + nodeStatus := node.StatusSuspended + if n.Type == 
model.MasterNodeType { + nodeType = node.TypeMaster + } + if n.Status == model.NodeActive { + nodeStatus = node.StatusActive + } + + cap := &boolset.BooleanSet{} + settings := &types.NodeSetting{ + Provider: types.DownloaderProviderAria2, + } + + if n.Aria2Enabled { + boolset.Sets(map[types.NodeCapability]bool{ + types.NodeCapabilityRemoteDownload: true, + }, cap) + + aria2Options := &model.Aria2Option{} + if err := json.Unmarshal([]byte(n.Aria2Options), aria2Options); err != nil { + return fmt.Errorf("failed to unmarshal aria2 options: %w", err) + } + + downloaderOptions := map[string]any{} + if aria2Options.Options != "" { + if err := json.Unmarshal([]byte(aria2Options.Options), &downloaderOptions); err != nil { + return fmt.Errorf("failed to unmarshal aria2 options: %w", err) + } + } + + settings.Aria2Setting = &types.Aria2Setting{ + Server: aria2Options.Server, + Token: aria2Options.Token, + Options: downloaderOptions, + TempPath: aria2Options.TempPath, + } + } + + if n.Type == model.MasterNodeType { + boolset.Sets(map[types.NodeCapability]bool{ + types.NodeCapabilityExtractArchive: true, + types.NodeCapabilityCreateArchive: true, + }, cap) + } + + stm := m.v4client.Node.Create(). + SetRawID(int(n.ID)). + SetCreatedAt(formatTime(n.CreatedAt)). + SetUpdatedAt(formatTime(n.UpdatedAt)). + SetName(n.Name). + SetType(nodeType). + SetStatus(nodeStatus). + SetServer(n.Server). + SetSlaveKey(n.SlaveKey). + SetCapabilities(cap). + SetSettings(settings). + SetWeight(n.Rank) + + if err := stm.Exec(context.Background()); err != nil { + return fmt.Errorf("failed to create node %q: %w", n.Name, err) + } + + } + + return nil +} diff --git a/application/migrator/policy.go b/application/migrator/policy.go new file mode 100644 index 00000000..83c6c5e0 --- /dev/null +++ b/application/migrator/policy.go @@ -0,0 +1,192 @@ +package migrator + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" +) + +func (m *Migrator) migratePolicy() (map[int]bool, error) { + m.l.Info("Migrating storage policies...") + var policies []model.Policy + if err := model.DB.Find(&policies).Error; err != nil { + return nil, fmt.Errorf("failed to list v3 storage policies: %w", err) + } + + if m.state.LocalPolicyIDs == nil { + m.state.LocalPolicyIDs = make(map[int]bool) + } + + if m.state.PolicyIDs == nil { + m.state.PolicyIDs = make(map[int]bool) + } + + m.l.Info("Found %d v3 storage policies to be migrated.", len(policies)) + + // get thumb proxy settings + var ( + thumbProxySettings []model.Setting + thumbProxyEnabled bool + thumbProxyPolicy []int + ) + if err := model.DB.Where("name in (?)", []string{"thumb_proxy_enabled", "thumb_proxy_policy"}).Find(&thumbProxySettings).Error; err != nil { + m.l.Warning("Failed to list v3 thumb proxy settings: %w", err) + } + + tx, err := m.v4client.Tx(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to start transaction: %w", err) + } + + for _, s := range thumbProxySettings { + if s.Name == "thumb_proxy_enabled" { + thumbProxyEnabled = setting.IsTrueValue(s.Value) + } else if s.Name == "thumb_proxy_policy" { + if err := json.Unmarshal([]byte(s.Value), &thumbProxyPolicy); err != nil { + m.l.Warning("Failed to unmarshal v3 thumb 
proxy policy: %w", err) + } + } + } + + for _, policy := range policies { + m.l.Info("Migrating storage policy %q...", policy.Name) + if err := json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized); err != nil { + return nil, fmt.Errorf("failed to unmarshal options for policy %q: %w", policy.Name, err) + } + + settings := &types.PolicySetting{ + Token: policy.OptionsSerialized.Token, + FileType: policy.OptionsSerialized.FileType, + OauthRedirect: policy.OptionsSerialized.OauthRedirect, + OdDriver: policy.OptionsSerialized.OdDriver, + Region: policy.OptionsSerialized.Region, + ServerSideEndpoint: policy.OptionsSerialized.ServerSideEndpoint, + ChunkSize: int64(policy.OptionsSerialized.ChunkSize), + TPSLimit: policy.OptionsSerialized.TPSLimit, + TPSLimitBurst: policy.OptionsSerialized.TPSLimitBurst, + S3ForcePathStyle: policy.OptionsSerialized.S3ForcePathStyle, + ThumbExts: policy.OptionsSerialized.ThumbExts, + } + + if policy.Type == types.PolicyTypeOd { + settings.ThumbSupportAllExts = true + } else { + switch policy.Type { + case types.PolicyTypeCos: + settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"} + case types.PolicyTypeOss: + settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"} + case types.PolicyTypeUpyun: + settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"} + case types.PolicyTypeQiniu: + settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"} + case types.PolicyTypeRemote: + settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif"} + } + } + + if policy.Type != types.PolicyTypeOd && policy.BaseURL != "" { + settings.CustomProxy = true + settings.ProxyServer = policy.BaseURL + } else if policy.OptionsSerialized.OdProxy != "" { + settings.CustomProxy = true + settings.ProxyServer = policy.OptionsSerialized.OdProxy + } + + if policy.DirNameRule == "" { + policy.DirNameRule = "uploads/{uid}/{path}" + } + + if policy.Type == types.PolicyTypeCos { + settings.ChunkSize = 1024 * 1024 * 25 + } + + if thumbProxyEnabled && lo.Contains(thumbProxyPolicy, int(policy.ID)) { + settings.ThumbGeneratorProxy = true + } + + mustContain := []string{"{randomkey16}", "{randomkey8}", "{uuid}"} + hasRandomElement := false + for _, c := range mustContain { + if strings.Contains(policy.FileNameRule, c) { + hasRandomElement = true + break + } + } + if !hasRandomElement { + policy.FileNameRule = "{uid}_{randomkey8}_{originname}" + m.l.Warning("Storage policy %q has no random element in file name rule, using default file name rule.", policy.Name) + } + + stm := tx.StoragePolicy.Create(). + SetRawID(int(policy.ID)). + SetCreatedAt(formatTime(policy.CreatedAt)). + SetUpdatedAt(formatTime(policy.UpdatedAt)). + SetName(policy.Name). + SetType(policy.Type). + SetServer(policy.Server). + SetBucketName(policy.BucketName). + SetIsPrivate(policy.IsPrivate). + SetAccessKey(policy.AccessKey). + SetSecretKey(policy.SecretKey). + SetMaxSize(int64(policy.MaxSize)). + SetDirNameRule(policy.DirNameRule). + SetFileNameRule(policy.FileNameRule). + SetSettings(settings) + + if policy.Type == types.PolicyTypeRemote { + m.l.Info("Storage policy %q is remote, creating node for it...", policy.Name) + bs := &boolset.BooleanSet{} + n, err := tx.Node.Create(). + SetName(policy.Name). + SetStatus(node.StatusActive). + SetServer(policy.Server). + SetSlaveKey(policy.SecretKey). + SetType(node.TypeSlave). + SetCapabilities(bs). 
+ SetSettings(&types.NodeSetting{ + Provider: types.DownloaderProviderAria2, + }). + Save(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to create node for storage policy %q: %w", policy.Name, err) + } + + stm.SetNodeID(n.ID) + } + + if _, err := stm.Save(context.Background()); err != nil { + return nil, fmt.Errorf("failed to create storage policy %q: %w", policy.Name, err) + } + + m.state.PolicyIDs[int(policy.ID)] = true + if policy.Type == types.PolicyTypeLocal { + m.state.LocalPolicyIDs[int(policy.ID)] = true + } + } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit transaction: %w", err) + } + + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting storage policy ID sequence for postgres...") + m.v4client.StoragePolicy.ExecContext(context.Background(), "SELECT SETVAL('storage_policies_id_seq', (SELECT MAX(id) FROM storage_policies))") + } + + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting node ID sequence for postgres...") + m.v4client.Node.ExecContext(context.Background(), "SELECT SETVAL('nodes_id_seq', (SELECT MAX(id) FROM nodes))") + } + + return m.state.PolicyIDs, nil +} diff --git a/application/migrator/settings.go b/application/migrator/settings.go new file mode 100644 index 00000000..db7e86f1 --- /dev/null +++ b/application/migrator/settings.go @@ -0,0 +1,213 @@ +package migrator + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/migrator/conf" + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" +) + +// TODO: +// 1. Policy thumb proxy migration + +type ( + settignMigrator func(allSettings map[string]string, name, value string) ([]settingMigrated, error) + settingMigrated struct { + name string + value string + } + // PackProduct 容量包商品 + PackProduct struct { + ID int64 `json:"id"` + Name string `json:"name"` + Size uint64 `json:"size"` + Time int64 `json:"time"` + Price int `json:"price"` + Score int `json:"score"` + } + GroupProducts struct { + ID int64 `json:"id"` + Name string `json:"name"` + GroupID uint `json:"group_id"` + Time int64 `json:"time"` + Price int `json:"price"` + Score int `json:"score"` + Des []string `json:"des"` + Highlight bool `json:"highlight"` + } +) + +var noopMigrator = func(allSettings map[string]string, name, value string) ([]settingMigrated, error) { + return nil, nil +} + +var migrators = map[string]settignMigrator{ + "siteKeywords": noopMigrator, + "over_used_template": noopMigrator, + "download_timeout": noopMigrator, + "preview_timeout": noopMigrator, + "doc_preview_timeout": noopMigrator, + "slave_node_retry": noopMigrator, + "slave_ping_interval": noopMigrator, + "slave_recover_interval": noopMigrator, + "slave_transfer_timeout": noopMigrator, + "onedrive_monitor_timeout": noopMigrator, + "onedrive_source_timeout": noopMigrator, + "share_download_session_timeout": noopMigrator, + "onedrive_callback_check": noopMigrator, + "mail_activation_template": noopMigrator, + "mail_reset_pwd_template": noopMigrator, + "appid": noopMigrator, + "appkey": noopMigrator, + "wechat_enabled": noopMigrator, + "wechat_appid": noopMigrator, + "wechat_mchid": noopMigrator, + "wechat_serial_no": noopMigrator, + "wechat_api_key": noopMigrator, + "wechat_pk_content": noopMigrator, + "hot_share_num": noopMigrator, + "defaultTheme": noopMigrator, + "theme_options": noopMigrator, + "max_worker_num": noopMigrator, + "max_parallel_transfer": noopMigrator, + "secret_key": noopMigrator, + "avatar_size_m": 
noopMigrator, + "avatar_size_s": noopMigrator, + "home_view_method": noopMigrator, + "share_view_method": noopMigrator, + "cron_recycle_upload_session": noopMigrator, + "captcha_type": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) { + if value == "tcaptcha" { + value = "normal" + } + return []settingMigrated{ + { + name: "captcha_type", + value: value, + }, + }, nil + }, + "captcha_TCaptcha_CaptchaAppId": noopMigrator, + "captcha_TCaptcha_AppSecretKey": noopMigrator, + "captcha_TCaptcha_SecretId": noopMigrator, + "captcha_TCaptcha_SecretKey": noopMigrator, + "thumb_file_suffix": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) { + return []settingMigrated{ + { + name: "thumb_entity_suffix", + value: value, + }, + }, nil + }, + "thumb_max_src_size": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) { + return []settingMigrated{ + { + name: "thumb_music_cover_max_size", + value: value, + }, + { + name: "thumb_libreoffice_max_size", + value: value, + }, + { + name: "thumb_ffmpeg_max_size", + value: value, + }, + { + name: "thumb_vips_max_size", + value: value, + }, + { + name: "thumb_builtin_max_size", + value: value, + }, + }, nil + }, + "initial_files": noopMigrator, + "office_preview_service": noopMigrator, + "phone_required": noopMigrator, + "phone_enabled": noopMigrator, + "wopi_session_timeout": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) { + return []settingMigrated{ + { + name: "viewer_session_timeout", + value: value, + }, + }, nil + }, + "custom_payment_enabled": noopMigrator, + "custom_payment_endpoint": noopMigrator, + "custom_payment_secret": noopMigrator, + "custom_payment_name": noopMigrator, +} + +func (m *Migrator) migrateSettings() error { + m.l.Info("Migrating settings...") + // 1. List all settings + var settings []model.Setting + if err := model.DB.Find(&settings).Error; err != nil { + return fmt.Errorf("failed to list v3 settings: %w", err) + } + + m.l.Info("Found %d v3 setting pairs to be migrated.", len(settings)) + + allSettings := make(map[string]string) + for _, s := range settings { + allSettings[s.Name] = s.Value + } + + migratedSettings := make([]settingMigrated, 0) + for _, s := range settings { + if s.Name == "thumb_file_suffix" { + m.state.ThumbSuffix = s.Value + } + if s.Name == "avatar_path" { + m.state.V3AvatarPath = s.Value + } + migrator, ok := migrators[s.Name] + if ok { + newSettings, err := migrator(allSettings, s.Name, s.Value) + if err != nil { + return fmt.Errorf("failed to migrate setting %q: %w", s.Name, err) + } + migratedSettings = append(migratedSettings, newSettings...) 
+ } else { + migratedSettings = append(migratedSettings, settingMigrated{ + name: s.Name, + value: s.Value, + }) + } + } + + tx, err := m.v4client.Tx(context.Background()) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + + // Insert hash_id_salt + if conf.SystemConfig.HashIDSalt != "" { + if err := tx.Setting.Create().SetName("hash_id_salt").SetValue(conf.SystemConfig.HashIDSalt).Exec(context.Background()); err != nil { + if err := tx.Rollback(); err != nil { + return fmt.Errorf("failed to rollback transaction: %w", err) + } + return fmt.Errorf("failed to create setting hash_id_salt: %w", err) + } + } else { + return fmt.Errorf("hash ID salt is not set, please set it from v3 conf file") + } + + for _, s := range migratedSettings { + if err := tx.Setting.Create().SetName(s.name).SetValue(s.value).Exec(context.Background()); err != nil { + if err := tx.Rollback(); err != nil { + return fmt.Errorf("failed to rollback transaction: %w", err) + } + return fmt.Errorf("failed to create setting %q: %w", s.name, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} diff --git a/application/migrator/share.go b/application/migrator/share.go new file mode 100644 index 00000000..8e5900d3 --- /dev/null +++ b/application/migrator/share.go @@ -0,0 +1,102 @@ +package migrator + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +func (m *Migrator) migrateShare() error { + m.l.Info("Migrating shares...") + batchSize := 1000 + offset := m.state.ShareOffset + ctx := context.Background() + + if offset > 0 { + m.l.Info("Resuming share migration from offset %d", offset) + } + + for { + m.l.Info("Migrating shares with offset %d", offset) + var shares []model.Share + if err := model.DB.Limit(batchSize).Offset(offset).Find(&shares).Error; err != nil { + return fmt.Errorf("failed to list v3 shares: %w", err) + } + + if len(shares) == 0 { + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting share ID sequence for postgres...") + m.v4client.Share.ExecContext(ctx, "SELECT SETVAL('shares_id_seq', (SELECT MAX(id) FROM shares))") + } + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, s := range shares { + sourceId := int(s.SourceID) + if !s.IsDir { + sourceId += m.state.LastFolderID + } + + // check if file exists + _, err = tx.File.Query().Where(file.ID(sourceId)).First(ctx) + if err != nil { + m.l.Warning("File %d not found, skipping share %d", sourceId, s.ID) + continue + } + + // check if user exist + if _, ok := m.state.UserIDs[int(s.UserID)]; !ok { + m.l.Warning("User %d not found, skipping share %d", s.UserID, s.ID) + continue + } + + stm := tx.Share.Create(). + SetCreatedAt(formatTime(s.CreatedAt)). + SetUpdatedAt(formatTime(s.UpdatedAt)). + SetViews(s.Views). + SetRawID(int(s.ID)). + SetDownloads(s.Downloads). + SetFileID(sourceId). 
+ SetUserID(int(s.UserID)) + + if s.Password != "" { + stm.SetPassword(s.Password) + } + + if s.Expires != nil { + stm.SetNillableExpires(s.Expires) + } + + if s.RemainDownloads >= 0 { + stm.SetRemainDownloads(s.RemainDownloads) + } + + if _, err := stm.Save(ctx); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to create share %d: %w", s.ID, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + offset += batchSize + m.state.ShareOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after share batch: %s", err) + } else { + m.l.Info("Saved migration state after processing this batch") + } + } + return nil +} diff --git a/application/migrator/user.go b/application/migrator/user.go new file mode 100644 index 00000000..273d8f09 --- /dev/null +++ b/application/migrator/user.go @@ -0,0 +1,109 @@ +package migrator + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +func (m *Migrator) migrateUser() error { + m.l.Info("Migrating users...") + batchSize := 1000 + // Start from the saved offset if available + offset := m.state.UserOffset + ctx := context.Background() + if m.state.UserIDs == nil { + m.state.UserIDs = make(map[int]bool) + } + + // If we're resuming, load existing user IDs + if len(m.state.UserIDs) > 0 { + m.l.Info("Resuming user migration from offset %d, %d users already migrated", offset, len(m.state.UserIDs)) + } + + for { + m.l.Info("Migrating users with offset %d", offset) + var users []model.User + if err := model.DB.Limit(batchSize).Offset(offset).Find(&users).Error; err != nil { + return fmt.Errorf("failed to list v3 users: %w", err) + } + + if len(users) == 0 { + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting user ID sequence for postgres...") + m.v4client.User.ExecContext(ctx, "SELECT SETVAL('users_id_seq', (SELECT MAX(id) FROM users))") + } + break + } + + tx, err := m.v4client.Tx(context.Background()) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, u := range users { + userStatus := user.StatusActive + switch u.Status { + case model.Active: + userStatus = user.StatusActive + case model.NotActivicated: + userStatus = user.StatusInactive + case model.Baned: + userStatus = user.StatusManualBanned + case model.OveruseBaned: + userStatus = user.StatusSysBanned + } + + setting := &types.UserSetting{ + VersionRetention: true, + VersionRetentionMax: 10, + } + + stm := tx.User.Create(). + SetRawID(int(u.ID)). + SetCreatedAt(formatTime(u.CreatedAt)). + SetUpdatedAt(formatTime(u.UpdatedAt)). + SetEmail(u.Email). + SetNick(u.Nick). + SetStatus(userStatus). + SetStorage(int64(u.Storage)). + SetGroupID(int(u.GroupID)). + SetSettings(setting). 
+ SetPassword(u.Password) + + if u.TwoFactor != "" { + stm.SetTwoFactorSecret(u.TwoFactor) + } + + if u.Avatar != "" { + stm.SetAvatar(u.Avatar) + } + + if _, err := stm.Save(ctx); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to create user %d: %w", u.ID, err) + } + + m.state.UserIDs[int(u.ID)] = true + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + // Update the offset in state and save after each batch + offset += batchSize + m.state.UserOffset = offset + if err := m.saveState(); err != nil { + m.l.Warning("Failed to save state after user batch: %s", err) + } else { + m.l.Info("Saved migration state after processing %d users", offset) + } + } + + return nil +} diff --git a/application/migrator/webdav.go b/application/migrator/webdav.go new file mode 100644 index 00000000..156543f5 --- /dev/null +++ b/application/migrator/webdav.go @@ -0,0 +1,93 @@ +package migrator + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/migrator/model" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +func (m *Migrator) migrateWebdav() error { + m.l.Info("Migrating webdav accounts...") + + batchSize := 1000 + offset := m.state.WebdavOffset + ctx := context.Background() + + if m.state.WebdavOffset > 0 { + m.l.Info("Resuming webdav migration from offset %d", offset) + } + + for { + m.l.Info("Migrating webdav accounts with offset %d", offset) + var webdavAccounts []model.Webdav + if err := model.DB.Limit(batchSize).Offset(offset).Find(&webdavAccounts).Error; err != nil { + return fmt.Errorf("failed to list v3 webdav accounts: %w", err) + } + + if len(webdavAccounts) == 0 { + if m.dep.ConfigProvider().Database().Type == conf.PostgresDB { + m.l.Info("Resetting webdav account ID sequence for postgres...") + m.v4client.DavAccount.ExecContext(ctx, "SELECT SETVAL('dav_accounts_id_seq', (SELECT MAX(id) FROM dav_accounts))") + } + break + } + + tx, err := m.v4client.Tx(ctx) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to start transaction: %w", err) + } + + for _, webdavAccount := range webdavAccounts { + if _, ok := m.state.UserIDs[int(webdavAccount.UserID)]; !ok { + m.l.Warning("User %d not found, skipping webdav account %d", webdavAccount.UserID, webdavAccount.ID) + continue + } + + props := types.DavAccountProps{} + options := boolset.BooleanSet{} + + if webdavAccount.Readonly { + boolset.Set(int(types.DavAccountReadOnly), true, &options) + } + + if webdavAccount.UseProxy { + boolset.Set(int(types.DavAccountProxy), true, &options) + } + + stm := tx.DavAccount.Create(). + SetCreatedAt(formatTime(webdavAccount.CreatedAt)). + SetUpdatedAt(formatTime(webdavAccount.UpdatedAt)). + SetRawID(int(webdavAccount.ID)). + SetName(webdavAccount.Name). + SetURI("cloudreve://my" + webdavAccount.Root). + SetPassword(webdavAccount.Password). + SetProps(&props). + SetOptions(&options). 
+				SetOwnerID(int(webdavAccount.UserID))
+
+			if _, err := stm.Save(ctx); err != nil {
+				_ = tx.Rollback()
+				return fmt.Errorf("failed to create webdav account %d: %w", webdavAccount.ID, err)
+			}
+		}
+
+		if err := tx.Commit(); err != nil {
+			return fmt.Errorf("failed to commit transaction: %w", err)
+		}
+
+		offset += batchSize
+		m.state.WebdavOffset = offset
+		if err := m.saveState(); err != nil {
+			m.l.Warning("Failed to save state after webdav batch: %s", err)
+		} else {
+			m.l.Info("Saved migration state after processing this batch")
+		}
+	}
+
+	return nil
+}
diff --git a/bootstrap/embed.go b/application/statics/embed.go
similarity index 99%
rename from bootstrap/embed.go
rename to application/statics/embed.go
index 71f75672..e502d8d8 100644
--- a/bootstrap/embed.go
+++ b/application/statics/embed.go
@@ -126,7 +126,7 @@
 // To support tools that analyze Go packages, the patterns found in //go:embed lines
 // are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
 // and XTestEmbedPatterns fields in the “go help list” output.
-package bootstrap
+package statics
 
 import (
 	"errors"
diff --git a/application/statics/statics.go b/application/statics/statics.go
new file mode 100644
index 00000000..e53642e7
--- /dev/null
+++ b/application/statics/statics.go
@@ -0,0 +1,206 @@
+package statics
+
+import (
+	"archive/zip"
+	"bufio"
+	"crypto/sha256"
+	_ "embed"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/fs"
+	"net/http"
+	"path/filepath"
+	"sort"
+	"strings"
+
+	"github.com/cloudreve/Cloudreve/v4/application/constants"
+	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
+	"github.com/cloudreve/Cloudreve/v4/pkg/util"
+	"github.com/gin-contrib/static"
+)
+
+const StaticFolder = "statics"
+
+//go:embed assets.zip
+var zipContent string
+
+type GinFS struct {
+	FS http.FileSystem
+}
+
+type version struct {
+	Name    string `json:"name"`
+	Version string `json:"version"`
+}
+
+// Open opens the named file.
+func (b *GinFS) Open(name string) (http.File, error) {
+	return b.FS.Open(name)
+}
+
+// Exists reports whether the given file exists.
+func (b *GinFS) Exists(prefix string, filepath string) bool {
+	if _, err := b.FS.Open(filepath); err != nil {
+		return false
+	}
+	return true
+}
+
+// NewServerStaticFS initializes the file system used to serve static resources.
+func NewServerStaticFS(l logging.Logger, statics fs.FS, isPro bool) (static.ServeFileSystem, error) {
+	var staticFS static.ServeFileSystem
+	if util.Exists(util.DataPath(StaticFolder)) {
+		l.Info("Folder %q already exists, it will be used to serve static files.", util.DataPath(StaticFolder))
+		staticFS = static.LocalFile(util.DataPath(StaticFolder), false)
+	} else {
+		// Initialize embedded static resources
+		embedFS, err := fs.Sub(statics, "assets/build")
+		if err != nil {
+			return nil, fmt.Errorf("failed to initialize static resources: %w", err)
+		}
+
+		staticFS = &GinFS{
+			FS: http.FS(embedFS),
+		}
+	}
+	// Check the version of the static resources
+	f, err := staticFS.Open("version.json")
+	if err != nil {
+		l.Warning("Missing version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
+		return staticFS, nil
+	}
+
+	b, err := io.ReadAll(f)
+	if err != nil {
+		l.Warning("Failed to read version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
+		return staticFS, nil
+	}
+
+	var v version
+	if err := json.Unmarshal(b, &v); err != nil {
+		l.Warning("Failed to parse version identifier file in static resources: %s", err)
+		return staticFS, nil
+	}
+
+	staticName := "cloudreve-frontend"
+	if isPro {
+		staticName += "-pro"
+	}
+
+	if v.Name != staticName {
+		l.Panic("Static resource name mismatch, please delete \"statics\" folder and rebuild it.")
+	}
+
+	if v.Version != constants.BackendVersion {
+		l.Panic("Static resource version mismatch [Current %s, Desired: %s], please delete \"statics\" folder and rebuild it.", v.Version, constants.BackendVersion)
+	}
+
+	return staticFS, nil
+}
+
+func NewStaticFS(l logging.Logger) fs.FS {
+	zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent)))
+	if err != nil {
+		l.Panic("Static resource is not a valid zip file: %s", err)
+	}
+
+	var files []file
+	err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return fmt.Errorf("cannot walk into %q: %w", path, err)
+		}
+
+		if path == "." {
+			return nil
+		}
+
+		var f file
+		if d.IsDir() {
+			f.name = path + "/"
+		} else {
+			f.name = path
+
+			rc, err := zipReader.Open(path)
+			if err != nil {
+				return fmt.Errorf("cannot open %q: %w", path, err)
+			}
+			defer rc.Close()
+
+			data, err := io.ReadAll(rc)
+			if err != nil {
+				return fmt.Errorf("cannot read %q: %w", path, err)
+			}
+
+			f.data = string(data)
+
+			hash := sha256.Sum256(data)
+			for i := range f.hash {
+				f.hash[i] = ^hash[i]
+			}
+		}
+		files = append(files, f)
+		return nil
+	})
+	if err != nil {
+		l.Panic("Failed to initialize static resources: %s", err)
+	}
+
+	sort.Slice(files, func(i, j int) bool {
+		fi, fj := files[i], files[j]
+		di, ei, _ := split(fi.name)
+		dj, ej, _ := split(fj.name)
+
+		if di != dj {
+			return di < dj
+		}
+		return ei < ej
+	})
+
+	var embedFS FS
+	embedFS.files = &files
+	return embedFS
+}
+
+// Eject extracts the embedded static resources into the "statics" folder under the data path.
+func Eject(l logging.Logger, statics fs.FS) error {
+	// Initialize embedded static resources
+	embedFS, err := fs.Sub(statics, "assets/build")
+	if err != nil {
+		l.Panic("Failed to initialize static resources: %s", err)
+	}
+
+	var walk func(relPath string, d fs.DirEntry, err error) error
+	walk = func(relPath string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return fmt.Errorf("failed to read info of %q: %s, skipping...", relPath, err)
+		}
+
+		if !d.IsDir() {
+			// Write the file to disk
+			dst := util.DataPath(filepath.Join(StaticFolder, relPath))
+			out, err := util.CreatNestedFile(dst)
+			if err != nil {
+				return fmt.Errorf("failed to create file %q: %s, skipping...", dst, err)
+			}
+			defer out.Close()
+
+			l.Info("Ejecting %q...", dst)
+			obj, _ := embedFS.Open(relPath)
+			if _, err := io.Copy(out, bufio.NewReader(obj)); err != nil {
+				return fmt.Errorf("cannot write file %q: %s, skipping...", relPath, err)
+			}
+		}
+		return nil
+	}
+
+	// util.Log().Info("Start ejecting embedded static resources...")
+	err = fs.WalkDir(embedFS, ".", walk)
+	if err != nil {
+		return fmt.Errorf("failed to eject static resources: %w", err)
+	}
+
+	l.Info("Finish ejecting static resources.")
+	return nil
+}
diff --git a/assets b/assets
index 5d4d01a7..b485bf29 160000
--- a/assets
+++ b/assets
@@ -1 +1 @@
-Subproject commit 5d4d01a797a1ba2d6866799684bf05de20006e31
+Subproject commit b485bf297974cbe4834d2e8e744ae7b7e5b2ad39
diff --git a/assets.zip b/assets.zip
deleted file mode 100644
index 15cb0ecb..00000000
Binary files a/assets.zip and /dev/null differ
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
new file mode 100644
index 00000000..c4325751
--- /dev/null
+++ b/azure-pipelines.yml
@@ -0,0 +1,49 @@
+trigger:
+  tags:
+    include:
+      - '*'
+variables:
+  GO_VERSION: "1.23.6"
+  NODE_VERSION: "22.x"
+  DOCKER_BUILDKIT: 1
+
+pool:
+  vmImage: ubuntu-latest
+
+jobs:
+  - job: Release
+    steps:
+      - checkout: self
+        submodules: true
+        persistCredentials: true
+      - task: NodeTool@0
+        inputs:
+          versionSpec: '$(NODE_VERSION)'
+        displayName: 
'Install Node.js' + - task: GoTool@0 + inputs: + version: "$(GO_VERSION)" + displayName: Install Go + - task: Docker@2 + inputs: + containerRegistry: "PRO ACR" + command: "login" + addPipelineData: false + addBaseImageData: false + - task: CmdLine@2 + displayName: "Install multiarch/qemu-user-static" + inputs: + script: | + docker run --rm --privileged multiarch/qemu-user-static --reset -p yes + - task: goreleaser@0 + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + inputs: + version: "latest" + distribution: "goreleaser" + workdir: "$(Build.SourcesDirectory)" + args: "release --timeout 60m" + env: + AWS_ACCESS_KEY_ID: $(AWS_ACCESS_KEY_ID) + AWS_SECRET_ACCESS_KEY: $(AWS_SECRET_ACCESS_KEY) + GITHUB_TOKEN: $(GITHUB_TOKEN) + \ No newline at end of file diff --git a/bootstrap/app.go b/bootstrap/app.go deleted file mode 100644 index 29065268..00000000 --- a/bootstrap/app.go +++ /dev/null @@ -1,58 +0,0 @@ -package bootstrap - -import ( - "encoding/json" - "fmt" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/hashicorp/go-version" -) - -// InitApplication 初始化应用常量 -func InitApplication() { - fmt.Print(` - ___ _ _ - / __\ | ___ _ _ __| |_ __ _____ _____ - / / | |/ _ \| | | |/ _ | '__/ _ \ \ / / _ \ -/ /___| | (_) | |_| | (_| | | | __/\ V / __/ -\____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___| - - V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Pro=` + conf.IsPro + ` -================================================ - -`) - go CheckUpdate() -} - -type GitHubRelease struct { - URL string `json:"html_url"` - Name string `json:"name"` - Tag string `json:"tag_name"` -} - -// CheckUpdate 检查更新 -func CheckUpdate() { - client := request.NewClient() - res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse() - if err != nil { - util.Log().Warning("更新检查失败, %s", err) - return - } - - var list []GitHubRelease - if err := json.Unmarshal([]byte(res), &list); err != nil { - util.Log().Warning("更新检查失败, %s", err) - return - } - - if len(list) > 0 { - present, err1 := version.NewVersion(conf.BackendVersion) - latest, err2 := version.NewVersion(list[0].Tag) - if err1 == nil && err2 == nil && latest.GreaterThan(present) { - util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL) - } - } - -} diff --git a/bootstrap/fs.go b/bootstrap/fs.go deleted file mode 100644 index a82396c2..00000000 --- a/bootstrap/fs.go +++ /dev/null @@ -1,75 +0,0 @@ -package bootstrap - -import ( - "archive/zip" - "crypto/sha256" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/pkg/errors" - "io" - "io/fs" - "sort" - "strings" -) - -func NewFS(zipContent string) fs.FS { - zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent))) - if err != nil { - util.Log().Panic("Static resource is not a valid zip file: %s", err) - } - - var files []file - err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return errors.Errorf("无法获取[%s]的信息, %s, 跳过...", path, err) - } - - if path == "." 
{ - return nil - } - - var f file - if d.IsDir() { - f.name = path + "/" - } else { - f.name = path - - rc, err := zipReader.Open(path) - if err != nil { - return errors.Errorf("无法打开文件[%s], %s, 跳过...", path, err) - } - defer rc.Close() - - data, err := io.ReadAll(rc) - if err != nil { - return errors.Errorf("无法读取文件[%s], %s, 跳过...", path, err) - } - - f.data = string(data) - - hash := sha256.Sum256(data) - for i := range f.hash { - f.hash[i] = ^hash[i] - } - } - files = append(files, f) - return nil - }) - if err != nil { - util.Log().Panic("初始化静态资源失败: %s", err) - } - - sort.Slice(files, func(i, j int) bool { - fi, fj := files[i], files[j] - di, ei, _ := split(fi.name) - dj, ej, _ := split(fj.name) - - if di != dj { - return di < dj - } - return ei < ej - }) - - var embedFS FS - embedFS.files = &files - return embedFS -} diff --git a/bootstrap/init.go b/bootstrap/init.go deleted file mode 100644 index e5f28000..00000000 --- a/bootstrap/init.go +++ /dev/null @@ -1,132 +0,0 @@ -package bootstrap - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/models/scripts" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/crontab" - "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/gin-gonic/gin" - "io/fs" - "path/filepath" -) - -// Init 初始化启动 -func Init(path string, statics fs.FS) { - InitApplication() - conf.Init(path) - // Debug 关闭时,切换为生产模式 - if !conf.SystemConfig.Debug { - gin.SetMode(gin.ReleaseMode) - } - - dependencies := []struct { - mode string - factory func() - }{ - { - "both", - func() { - scripts.Init() - }, - }, - { - "both", - func() { - cache.Init() - }, - }, - { - "slave", - func() { - model.InitSlaveDefaults() - }, - }, - { - "slave", - func() { - cache.InitSlaveOverwrites() - }, - }, - { - "master", - func() { - model.Init() - }, - }, - { - "both", - func() { - cache.Restore(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)) - }, - }, - { - "both", - func() { - task.Init() - }, - }, - { - "master", - func() { - cluster.Init() - }, - }, - { - "master", - func() { - aria2.Init(false, cluster.Default, mq.GlobalMQ) - }, - }, - { - "master", - func() { - email.Init() - }, - }, - { - "master", - func() { - crontab.Init() - }, - }, - { - "master", - func() { - InitStatic(statics) - }, - }, - { - "slave", - func() { - cluster.InitController() - }, - }, - { - "both", - func() { - auth.Init() - }, - }, - { - "master", - func() { - wopi.Init() - }, - }, - } - - for _, dependency := range dependencies { - if dependency.mode == conf.SystemConfig.Mode || dependency.mode == "both" { - dependency.factory() - } - } - -} diff --git a/bootstrap/script.go b/bootstrap/script.go deleted file mode 100644 index 6f0ac928..00000000 --- a/bootstrap/script.go +++ /dev/null @@ -1,18 +0,0 @@ -package bootstrap - -import ( - "context" - "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -func RunScript(name string) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if err := invoker.RunDBScript(name, ctx); err != nil { - util.Log().Error("Failed to execute database script: %s", err) - 
return - } - - util.Log().Info("Finish executing database script %q.", name) -} diff --git a/bootstrap/static.go b/bootstrap/static.go deleted file mode 100644 index 233e22af..00000000 --- a/bootstrap/static.go +++ /dev/null @@ -1,136 +0,0 @@ -package bootstrap - -import ( - "bufio" - "encoding/json" - "io" - "io/fs" - "net/http" - "path/filepath" - - "github.com/pkg/errors" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - - "github.com/gin-contrib/static" -) - -const StaticFolder = "statics" - -type GinFS struct { - FS http.FileSystem -} - -type staticVersion struct { - Name string `json:"name"` - Version string `json:"version"` -} - -// StaticFS 内置静态文件资源 -var StaticFS static.ServeFileSystem - -// Open 打开文件 -func (b *GinFS) Open(name string) (http.File, error) { - return b.FS.Open(name) -} - -// Exists 文件是否存在 -func (b *GinFS) Exists(prefix string, filepath string) bool { - if _, err := b.FS.Open(filepath); err != nil { - return false - } - return true -} - -// InitStatic 初始化静态资源文件 -func InitStatic(statics fs.FS) { - if util.Exists(util.RelativePath(StaticFolder)) { - util.Log().Info("Folder with name \"statics\" already exists, it will be used to serve static files.") - StaticFS = static.LocalFile(util.RelativePath("statics"), false) - } else { - // 初始化静态资源 - embedFS, err := fs.Sub(statics, "assets/build") - if err != nil { - util.Log().Panic("Failed to initialize static resources: %s", err) - } - - StaticFS = &GinFS{ - FS: http.FS(embedFS), - } - } - // 检查静态资源的版本 - f, err := StaticFS.Open("version.json") - if err != nil { - util.Log().Warning("Missing version identifier file in static resources, please delete \"statics\" folder and rebuild it.") - return - } - - b, err := io.ReadAll(f) - if err != nil { - util.Log().Warning("Failed to read version identifier file in static resources, please delete \"statics\" folder and rebuild it.") - return - } - - var v staticVersion - if err := json.Unmarshal(b, &v); err != nil { - util.Log().Warning("Failed to parse version identifier file in static resources: %s", err) - return - } - - staticName := "cloudreve-frontend" - if conf.IsPro == "true" { - staticName += "-pro" - } - - if v.Name != staticName { - util.Log().Warning("Static resource version mismatch, please delete \"statics\" folder and rebuild it.") - return - } - - if v.Version != conf.RequiredStaticVersion { - util.Log().Warning("Static resource version mismatch [Current %s, Desired: %s],please delete \"statics\" folder and rebuild it.", v.Version, conf.RequiredStaticVersion) - return - } -} - -// Eject 抽离内置静态资源 -func Eject(statics fs.FS) { - // 初始化静态资源 - embedFS, err := fs.Sub(statics, "assets/build") - if err != nil { - util.Log().Panic("Failed to initialize static resources: %s", err) - } - - var walk func(relPath string, d fs.DirEntry, err error) error - walk = func(relPath string, d fs.DirEntry, err error) error { - if err != nil { - return errors.Errorf("Failed to read info of %q: %s, skipping...", relPath, err) - } - - if !d.IsDir() { - // 写入文件 - out, err := util.CreatNestedFile(filepath.Join(util.RelativePath(""), StaticFolder, relPath)) - defer out.Close() - - if err != nil { - return errors.Errorf("Failed to create file %q: %s, skipping...", relPath, err) - } - - util.Log().Info("Ejecting %q...", relPath) - obj, _ := embedFS.Open(relPath) - if _, err := io.Copy(out, bufio.NewReader(obj)); err != nil { - return errors.Errorf("Cannot write file %q: %s, skipping...", relPath, err) - } - } - return nil - } - - // 
util.Log().Info("开始导出内置静态资源...") - err = fs.WalkDir(embedFS, ".", walk) - if err != nil { - util.Log().Error("Error occurs while ejecting static resources: %s", err) - return - } - util.Log().Info("Finish ejecting static resources.") -} diff --git a/cmd/eject.go b/cmd/eject.go new file mode 100644 index 00000000..93acb585 --- /dev/null +++ b/cmd/eject.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/application/statics" + "github.com/spf13/cobra" + "os" +) + +func init() { + rootCmd.AddCommand(ejectCmd) +} + +var ejectCmd = &cobra.Command{ + Use: "eject", + Short: "Eject all embedded static files", + Run: func(cmd *cobra.Command, args []string) { + dep := dependency.NewDependency( + dependency.WithConfigPath(confPath), + dependency.WithProFlag(constants.IsPro == "true"), + ) + logger := dep.Logger() + + if err := statics.Eject(dep.Logger(), dep.Statics()); err != nil { + logger.Error("Failed to eject static files: %s", err) + os.Exit(1) + } + }, +} diff --git a/cmd/migrate.go b/cmd/migrate.go new file mode 100644 index 00000000..a84ed87f --- /dev/null +++ b/cmd/migrate.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "os" + "path/filepath" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/application/migrator" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/spf13/cobra" +) + +var ( + v3ConfPath string + forceReset bool +) + +func init() { + rootCmd.AddCommand(migrateCmd) + migrateCmd.PersistentFlags().StringVar(&v3ConfPath, "v3-conf", "", "Path to the v3 config file") + migrateCmd.PersistentFlags().BoolVar(&forceReset, "force-reset", false, "Force reset migration state and start from beginning") +} + +var migrateCmd = &cobra.Command{ + Use: "migrate", + Short: "Migrate from v3 to v4", + Run: func(cmd *cobra.Command, args []string) { + dep := dependency.NewDependency( + dependency.WithConfigPath(confPath), + dependency.WithRequiredDbVersion(constants.BackendVersion), + dependency.WithProFlag(constants.IsPro == "true"), + ) + logger := dep.Logger() + logger.Info("Migrating from v3 to v4...") + + if v3ConfPath == "" { + logger.Error("v3 config file is required, please use -v3-conf to specify the path.") + os.Exit(1) + } + + // Check if state file exists and warn about resuming + stateFilePath := filepath.Join(filepath.Dir(v3ConfPath), "migration_state.json") + if util.Exists(stateFilePath) && !forceReset { + logger.Info("Found existing migration state file at %s. Migration will resume from the last successful step.", stateFilePath) + logger.Info("If you want to start migration from the beginning, please use --force-reset flag.") + } else if forceReset && util.Exists(stateFilePath) { + logger.Info("Force resetting migration state. Will start from the beginning.") + if err := os.Remove(stateFilePath); err != nil { + logger.Error("Failed to remove migration state file: %s", err) + os.Exit(1) + } + } + + migrator, err := migrator.NewMigrator(dep, v3ConfPath) + if err != nil { + logger.Error("Failed to create migrator: %s", err) + os.Exit(1) + } + + if err := migrator.Migrate(); err != nil { + logger.Error("Failed to migrate: %s", err) + logger.Info("Migration failed but state has been saved. 
You can retry with the same command to resume from the last successful step.") + os.Exit(1) + } + + logger.Info("Migration from v3 to v4 completed successfully.") + }, +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 00000000..0b7b2ff9 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "os" +) + +var ( + confPath string +) + +func init() { + rootCmd.PersistentFlags().StringVarP(&confPath, "conf", "c", util.DataPath("conf.ini"), "Path to the config file") + rootCmd.PersistentFlags().BoolVarP(&util.UseWorkingDir, "use-working-dir", "w", false, "Use working directory, instead of executable directory") +} + +var rootCmd = &cobra.Command{ + Use: "cloudreve", + Short: "Cloudreve is a server-side self-hosted cloud storage platform", + Long: `Self-hosted file management and sharing system, supports multiple storage providers. +Complete documentation is available at https://docs.cloudreve.org/`, + Run: func(cmd *cobra.Command, args []string) { + // Do Stuff Here + }, +} + +func Execute() { + cmd, _, err := rootCmd.Find(os.Args[1:]) + // redirect to default server cmd if no cmd is given + if err == nil && cmd.Use == rootCmd.Use && cmd.Flags().Parse(os.Args[1:]) != pflag.ErrHelp { + args := append([]string{"server"}, os.Args[1:]...) + rootCmd.SetArgs(args) + } + + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/cmd/server.go b/cmd/server.go new file mode 100644 index 00000000..1713a31e --- /dev/null +++ b/cmd/server.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "os" + "os/signal" + "syscall" + + "github.com/cloudreve/Cloudreve/v4/application" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/spf13/cobra" +) + +var ( + licenseKey string +) + +func init() { + rootCmd.AddCommand(serverCmd) + serverCmd.PersistentFlags().StringVarP(&licenseKey, "license-key", "l", "", "License key of your Cloudreve Pro") +} + +var serverCmd = &cobra.Command{ + Use: "server", + Short: "Start a Cloudreve server with the given config file", + Run: func(cmd *cobra.Command, args []string) { + dep := dependency.NewDependency( + dependency.WithConfigPath(confPath), + dependency.WithProFlag(constants.IsProBool), + dependency.WithRequiredDbVersion(constants.BackendVersion), + dependency.WithLicenseKey(licenseKey), + ) + server := application.NewServer(dep) + logger := dep.Logger() + + server.PrintBanner() + + // Graceful shutdown after received signal. 
+ sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go shutdown(sigChan, logger, server) + + if err := server.Start(); err != nil { + logger.Error("Failed to start server: %s", err) + os.Exit(1) + } + + defer func() { + <-sigChan + }() + }, +} + +func shutdown(sigChan chan os.Signal, logger logging.Logger, server application.Server) { + sig := <-sigChan + logger.Info("Signal %s received, shutting down server...", sig) + server.Close() + close(sigChan) +} diff --git a/docker-compose.yml b/docker-compose.yml index cc94bef7..129d26c4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,45 +1,40 @@ -version: "3.8" services: - redis: - container_name: redis - image: bitnami/redis:latest - restart: unless-stopped + pro: + image: cloudreve.azurecr.io/cloudreve/pro:latest + container_name: cloudreve-pro-backend + depends_on: + - postgresql + - redis + restart: always + ports: + - 5212:5212 environment: - - ALLOW_EMPTY_PASSWORD=yes + - CR_CONF_Database.Type=postgres + - CR_CONF_Database.Host=postgresql + - CR_CONF_Database.User=cloudreve + - CR_CONF_Database.Name=cloudreve + - CR_CONF_Database.Port=5432 + - CR_CONF_Redis.Server=redis:6379 + - CR_LICENSE_KEY=${CR_LICENSE_KEY} volumes: - - redis_data:/bitnami/redis/data + - backend_data:/cloudreve/data - cloudreve: - container_name: cloudreve - image: cloudreve/cloudreve:latest - restart: unless-stopped - ports: - - "5212:5212" + postgresql: + image: postgres:latest + container_name: postgresql + environment: + - POSTGRES_USER=cloudreve + - POSTGRES_DB=cloudreve + - POSTGRES_HOST_AUTH_METHOD=trust volumes: - - temp_data:/data - - ./cloudreve/uploads:/cloudreve/uploads - - ./cloudreve/conf.ini:/cloudreve/conf.ini - - ./cloudreve/cloudreve.db:/cloudreve/cloudreve.db - - ./cloudreve/avatar:/cloudreve/avatar - depends_on: - - aria2 + - database_postgres:/var/lib/postgresql/data - aria2: - container_name: aria2 - image: p3terx/aria2-pro # third party image, please keep notice what you are doing - restart: unless-stopped - environment: - - RPC_SECRET=your_aria_rpc_token # aria rpc token, customize your own - - RPC_PORT=6800 + redis: + image: redis:latest + container_name: redis volumes: - - ./aria2/config:/config - - temp_data:/data + - backend_data:/data + volumes: - redis_data: - driver: local - temp_data: - driver: local - driver_opts: - type: none - device: $PWD/data - o: bind + backend_data: + database_postgres: diff --git a/ent/client.go b/ent/client.go new file mode 100644 index 00000000..4df84c8a --- /dev/null +++ b/ent/client.go @@ -0,0 +1,2625 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "log" + "reflect" + + "github.com/cloudreve/Cloudreve/v4/ent/migrate" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + + stdsql "database/sql" +) + +// Client is the client that holds all ent builders. +type Client struct { + config + // Schema is the client for creating, migrating and dropping schema. + Schema *migrate.Schema + // DavAccount is the client for interacting with the DavAccount builders. + DavAccount *DavAccountClient + // DirectLink is the client for interacting with the DirectLink builders. + DirectLink *DirectLinkClient + // Entity is the client for interacting with the Entity builders. + Entity *EntityClient + // File is the client for interacting with the File builders. + File *FileClient + // Group is the client for interacting with the Group builders. + Group *GroupClient + // Metadata is the client for interacting with the Metadata builders. + Metadata *MetadataClient + // Node is the client for interacting with the Node builders. + Node *NodeClient + // Passkey is the client for interacting with the Passkey builders. + Passkey *PasskeyClient + // Setting is the client for interacting with the Setting builders. + Setting *SettingClient + // Share is the client for interacting with the Share builders. + Share *ShareClient + // StoragePolicy is the client for interacting with the StoragePolicy builders. + StoragePolicy *StoragePolicyClient + // Task is the client for interacting with the Task builders. + Task *TaskClient + // User is the client for interacting with the User builders. + User *UserClient +} + +// NewClient creates a new client configured with the given options. +func NewClient(opts ...Option) *Client { + client := &Client{config: newConfig(opts...)} + client.init() + return client +} + +func (c *Client) init() { + c.Schema = migrate.NewSchema(c.driver) + c.DavAccount = NewDavAccountClient(c.config) + c.DirectLink = NewDirectLinkClient(c.config) + c.Entity = NewEntityClient(c.config) + c.File = NewFileClient(c.config) + c.Group = NewGroupClient(c.config) + c.Metadata = NewMetadataClient(c.config) + c.Node = NewNodeClient(c.config) + c.Passkey = NewPasskeyClient(c.config) + c.Setting = NewSettingClient(c.config) + c.Share = NewShareClient(c.config) + c.StoragePolicy = NewStoragePolicyClient(c.config) + c.Task = NewTaskClient(c.config) + c.User = NewUserClient(c.config) +} + +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. 
+ Option func(*config) +) + +// newConfig creates a new config for the client. +func newConfig(opts ...Option) config { + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.options(opts...) + return cfg +} + +// options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } +} + +// Open opens a database/sql.DB specified by the driver name and +// the data source name, and returns a new client attached to it. +// Optional parameters can be added for configuring the client. +func Open(driverName, dataSourceName string, options ...Option) (*Client, error) { + switch driverName { + case dialect.MySQL, dialect.Postgres, dialect.SQLite: + drv, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return NewClient(append(options, Driver(drv))...), nil + default: + return nil, fmt.Errorf("unsupported driver: %q", driverName) + } +} + +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + +// Tx returns a new transactional client. The provided context +// is used until the transaction is committed or rolled back. +func (c *Client) Tx(ctx context.Context) (*Tx, error) { + if _, ok := c.driver.(*txDriver); ok { + return nil, ErrTxStarted + } + tx, err := newTx(ctx, c.driver) + if err != nil { + return nil, fmt.Errorf("ent: starting a transaction: %w", err) + } + cfg := c.config + cfg.driver = tx + return &Tx{ + ctx: ctx, + config: cfg, + DavAccount: NewDavAccountClient(cfg), + DirectLink: NewDirectLinkClient(cfg), + Entity: NewEntityClient(cfg), + File: NewFileClient(cfg), + Group: NewGroupClient(cfg), + Metadata: NewMetadataClient(cfg), + Node: NewNodeClient(cfg), + Passkey: NewPasskeyClient(cfg), + Setting: NewSettingClient(cfg), + Share: NewShareClient(cfg), + StoragePolicy: NewStoragePolicyClient(cfg), + Task: NewTaskClient(cfg), + User: NewUserClient(cfg), + }, nil +} + +// BeginTx returns a transactional client with specified options. 
+func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + if _, ok := c.driver.(*txDriver); ok { + return nil, errors.New("ent: cannot start a transaction within a transaction") + } + tx, err := c.driver.(interface { + BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) + }).BeginTx(ctx, opts) + if err != nil { + return nil, fmt.Errorf("ent: starting a transaction: %w", err) + } + cfg := c.config + cfg.driver = &txDriver{tx: tx, drv: c.driver} + return &Tx{ + ctx: ctx, + config: cfg, + DavAccount: NewDavAccountClient(cfg), + DirectLink: NewDirectLinkClient(cfg), + Entity: NewEntityClient(cfg), + File: NewFileClient(cfg), + Group: NewGroupClient(cfg), + Metadata: NewMetadataClient(cfg), + Node: NewNodeClient(cfg), + Passkey: NewPasskeyClient(cfg), + Setting: NewSettingClient(cfg), + Share: NewShareClient(cfg), + StoragePolicy: NewStoragePolicyClient(cfg), + Task: NewTaskClient(cfg), + User: NewUserClient(cfg), + }, nil +} + +// Debug returns a new debug-client. It's used to get verbose logging on specific operations. +// +// client.Debug(). +// DavAccount. +// Query(). +// Count(ctx) +func (c *Client) Debug() *Client { + if c.debug { + return c + } + cfg := c.config + cfg.driver = dialect.Debug(c.driver, c.log) + client := &Client{config: cfg} + client.init() + return client +} + +// Close closes the database connection and prevents new queries from starting. +func (c *Client) Close() error { + return c.driver.Close() +} + +// Use adds the mutation hooks to all the entity clients. +// In order to add hooks to a specific client, call: `client.Node.Use(...)`. +func (c *Client) Use(hooks ...Hook) { + for _, n := range []interface{ Use(...Hook) }{ + c.DavAccount, c.DirectLink, c.Entity, c.File, c.Group, c.Metadata, c.Node, + c.Passkey, c.Setting, c.Share, c.StoragePolicy, c.Task, c.User, + } { + n.Use(hooks...) + } +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + for _, n := range []interface{ Intercept(...Interceptor) }{ + c.DavAccount, c.DirectLink, c.Entity, c.File, c.Group, c.Metadata, c.Node, + c.Passkey, c.Setting, c.Share, c.StoragePolicy, c.Task, c.User, + } { + n.Intercept(interceptors...) + } +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *DavAccountMutation: + return c.DavAccount.mutate(ctx, m) + case *DirectLinkMutation: + return c.DirectLink.mutate(ctx, m) + case *EntityMutation: + return c.Entity.mutate(ctx, m) + case *FileMutation: + return c.File.mutate(ctx, m) + case *GroupMutation: + return c.Group.mutate(ctx, m) + case *MetadataMutation: + return c.Metadata.mutate(ctx, m) + case *NodeMutation: + return c.Node.mutate(ctx, m) + case *PasskeyMutation: + return c.Passkey.mutate(ctx, m) + case *SettingMutation: + return c.Setting.mutate(ctx, m) + case *ShareMutation: + return c.Share.mutate(ctx, m) + case *StoragePolicyMutation: + return c.StoragePolicy.mutate(ctx, m) + case *TaskMutation: + return c.Task.mutate(ctx, m) + case *UserMutation: + return c.User.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } +} + +// DavAccountClient is a client for the DavAccount schema. +type DavAccountClient struct { + config +} + +// NewDavAccountClient returns a client for the DavAccount from the given config. 
+func NewDavAccountClient(c config) *DavAccountClient { + return &DavAccountClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `davaccount.Hooks(f(g(h())))`. +func (c *DavAccountClient) Use(hooks ...Hook) { + c.hooks.DavAccount = append(c.hooks.DavAccount, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `davaccount.Intercept(f(g(h())))`. +func (c *DavAccountClient) Intercept(interceptors ...Interceptor) { + c.inters.DavAccount = append(c.inters.DavAccount, interceptors...) +} + +// Create returns a builder for creating a DavAccount entity. +func (c *DavAccountClient) Create() *DavAccountCreate { + mutation := newDavAccountMutation(c.config, OpCreate) + return &DavAccountCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of DavAccount entities. +func (c *DavAccountClient) CreateBulk(builders ...*DavAccountCreate) *DavAccountCreateBulk { + return &DavAccountCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DavAccountClient) MapCreateBulk(slice any, setFunc func(*DavAccountCreate, int)) *DavAccountCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DavAccountCreateBulk{err: fmt.Errorf("calling to DavAccountClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DavAccountCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DavAccountCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for DavAccount. +func (c *DavAccountClient) Update() *DavAccountUpdate { + mutation := newDavAccountMutation(c.config, OpUpdate) + return &DavAccountUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *DavAccountClient) UpdateOne(da *DavAccount) *DavAccountUpdateOne { + mutation := newDavAccountMutation(c.config, OpUpdateOne, withDavAccount(da)) + return &DavAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *DavAccountClient) UpdateOneID(id int) *DavAccountUpdateOne { + mutation := newDavAccountMutation(c.config, OpUpdateOne, withDavAccountID(id)) + return &DavAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for DavAccount. +func (c *DavAccountClient) Delete() *DavAccountDelete { + mutation := newDavAccountMutation(c.config, OpDelete) + return &DavAccountDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *DavAccountClient) DeleteOne(da *DavAccount) *DavAccountDeleteOne { + return c.DeleteOneID(da.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *DavAccountClient) DeleteOneID(id int) *DavAccountDeleteOne { + builder := c.Delete().Where(davaccount.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &DavAccountDeleteOne{builder} +} + +// Query returns a query builder for DavAccount. 
+func (c *DavAccountClient) Query() *DavAccountQuery { + return &DavAccountQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeDavAccount}, + inters: c.Interceptors(), + } +} + +// Get returns a DavAccount entity by its id. +func (c *DavAccountClient) Get(ctx context.Context, id int) (*DavAccount, error) { + return c.Query().Where(davaccount.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *DavAccountClient) GetX(ctx context.Context, id int) *DavAccount { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryOwner queries the owner edge of a DavAccount. +func (c *DavAccountClient) QueryOwner(da *DavAccount) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := da.ID + step := sqlgraph.NewStep( + sqlgraph.From(davaccount.Table, davaccount.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, davaccount.OwnerTable, davaccount.OwnerColumn), + ) + fromV = sqlgraph.Neighbors(da.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *DavAccountClient) Hooks() []Hook { + hooks := c.hooks.DavAccount + return append(hooks[:len(hooks):len(hooks)], davaccount.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *DavAccountClient) Interceptors() []Interceptor { + inters := c.inters.DavAccount + return append(inters[:len(inters):len(inters)], davaccount.Interceptors[:]...) +} + +func (c *DavAccountClient) mutate(ctx context.Context, m *DavAccountMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DavAccountCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DavAccountUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DavAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DavAccountDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown DavAccount mutation op: %q", m.Op()) + } +} + +// DirectLinkClient is a client for the DirectLink schema. +type DirectLinkClient struct { + config +} + +// NewDirectLinkClient returns a client for the DirectLink from the given config. +func NewDirectLinkClient(c config) *DirectLinkClient { + return &DirectLinkClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `directlink.Hooks(f(g(h())))`. +func (c *DirectLinkClient) Use(hooks ...Hook) { + c.hooks.DirectLink = append(c.hooks.DirectLink, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `directlink.Intercept(f(g(h())))`. +func (c *DirectLinkClient) Intercept(interceptors ...Interceptor) { + c.inters.DirectLink = append(c.inters.DirectLink, interceptors...) +} + +// Create returns a builder for creating a DirectLink entity. +func (c *DirectLinkClient) Create() *DirectLinkCreate { + mutation := newDirectLinkMutation(c.config, OpCreate) + return &DirectLinkCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of DirectLink entities. 
+func (c *DirectLinkClient) CreateBulk(builders ...*DirectLinkCreate) *DirectLinkCreateBulk { + return &DirectLinkCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DirectLinkClient) MapCreateBulk(slice any, setFunc func(*DirectLinkCreate, int)) *DirectLinkCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DirectLinkCreateBulk{err: fmt.Errorf("calling to DirectLinkClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DirectLinkCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DirectLinkCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for DirectLink. +func (c *DirectLinkClient) Update() *DirectLinkUpdate { + mutation := newDirectLinkMutation(c.config, OpUpdate) + return &DirectLinkUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *DirectLinkClient) UpdateOne(dl *DirectLink) *DirectLinkUpdateOne { + mutation := newDirectLinkMutation(c.config, OpUpdateOne, withDirectLink(dl)) + return &DirectLinkUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *DirectLinkClient) UpdateOneID(id int) *DirectLinkUpdateOne { + mutation := newDirectLinkMutation(c.config, OpUpdateOne, withDirectLinkID(id)) + return &DirectLinkUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for DirectLink. +func (c *DirectLinkClient) Delete() *DirectLinkDelete { + mutation := newDirectLinkMutation(c.config, OpDelete) + return &DirectLinkDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *DirectLinkClient) DeleteOne(dl *DirectLink) *DirectLinkDeleteOne { + return c.DeleteOneID(dl.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *DirectLinkClient) DeleteOneID(id int) *DirectLinkDeleteOne { + builder := c.Delete().Where(directlink.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &DirectLinkDeleteOne{builder} +} + +// Query returns a query builder for DirectLink. +func (c *DirectLinkClient) Query() *DirectLinkQuery { + return &DirectLinkQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeDirectLink}, + inters: c.Interceptors(), + } +} + +// Get returns a DirectLink entity by its id. +func (c *DirectLinkClient) Get(ctx context.Context, id int) (*DirectLink, error) { + return c.Query().Where(directlink.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *DirectLinkClient) GetX(ctx context.Context, id int) *DirectLink { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryFile queries the file edge of a DirectLink. 
+func (c *DirectLinkClient) QueryFile(dl *DirectLink) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := dl.ID + step := sqlgraph.NewStep( + sqlgraph.From(directlink.Table, directlink.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, directlink.FileTable, directlink.FileColumn), + ) + fromV = sqlgraph.Neighbors(dl.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *DirectLinkClient) Hooks() []Hook { + hooks := c.hooks.DirectLink + return append(hooks[:len(hooks):len(hooks)], directlink.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *DirectLinkClient) Interceptors() []Interceptor { + inters := c.inters.DirectLink + return append(inters[:len(inters):len(inters)], directlink.Interceptors[:]...) +} + +func (c *DirectLinkClient) mutate(ctx context.Context, m *DirectLinkMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DirectLinkCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DirectLinkUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DirectLinkUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DirectLinkDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown DirectLink mutation op: %q", m.Op()) + } +} + +// EntityClient is a client for the Entity schema. +type EntityClient struct { + config +} + +// NewEntityClient returns a client for the Entity from the given config. +func NewEntityClient(c config) *EntityClient { + return &EntityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `entity.Hooks(f(g(h())))`. +func (c *EntityClient) Use(hooks ...Hook) { + c.hooks.Entity = append(c.hooks.Entity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `entity.Intercept(f(g(h())))`. +func (c *EntityClient) Intercept(interceptors ...Interceptor) { + c.inters.Entity = append(c.inters.Entity, interceptors...) +} + +// Create returns a builder for creating a Entity entity. +func (c *EntityClient) Create() *EntityCreate { + mutation := newEntityMutation(c.config, OpCreate) + return &EntityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Entity entities. +func (c *EntityClient) CreateBulk(builders ...*EntityCreate) *EntityCreateBulk { + return &EntityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *EntityClient) MapCreateBulk(slice any, setFunc func(*EntityCreate, int)) *EntityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &EntityCreateBulk{err: fmt.Errorf("calling to EntityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*EntityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &EntityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Entity. 
+func (c *EntityClient) Update() *EntityUpdate { + mutation := newEntityMutation(c.config, OpUpdate) + return &EntityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *EntityClient) UpdateOne(e *Entity) *EntityUpdateOne { + mutation := newEntityMutation(c.config, OpUpdateOne, withEntity(e)) + return &EntityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *EntityClient) UpdateOneID(id int) *EntityUpdateOne { + mutation := newEntityMutation(c.config, OpUpdateOne, withEntityID(id)) + return &EntityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Entity. +func (c *EntityClient) Delete() *EntityDelete { + mutation := newEntityMutation(c.config, OpDelete) + return &EntityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *EntityClient) DeleteOne(e *Entity) *EntityDeleteOne { + return c.DeleteOneID(e.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *EntityClient) DeleteOneID(id int) *EntityDeleteOne { + builder := c.Delete().Where(entity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &EntityDeleteOne{builder} +} + +// Query returns a query builder for Entity. +func (c *EntityClient) Query() *EntityQuery { + return &EntityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeEntity}, + inters: c.Interceptors(), + } +} + +// Get returns a Entity entity by its id. +func (c *EntityClient) Get(ctx context.Context, id int) (*Entity, error) { + return c.Query().Where(entity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *EntityClient) GetX(ctx context.Context, id int) *Entity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryFile queries the file edge of a Entity. +func (c *EntityClient) QueryFile(e *Entity) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := e.ID + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, entity.FileTable, entity.FilePrimaryKey...), + ) + fromV = sqlgraph.Neighbors(e.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryUser queries the user edge of a Entity. +func (c *EntityClient) QueryUser(e *Entity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := e.ID + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, entity.UserTable, entity.UserColumn), + ) + fromV = sqlgraph.Neighbors(e.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryStoragePolicy queries the storage_policy edge of a Entity. 
+func (c *EntityClient) QueryStoragePolicy(e *Entity) *StoragePolicyQuery { + query := (&StoragePolicyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := e.ID + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, id), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, entity.StoragePolicyTable, entity.StoragePolicyColumn), + ) + fromV = sqlgraph.Neighbors(e.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *EntityClient) Hooks() []Hook { + hooks := c.hooks.Entity + return append(hooks[:len(hooks):len(hooks)], entity.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *EntityClient) Interceptors() []Interceptor { + inters := c.inters.Entity + return append(inters[:len(inters):len(inters)], entity.Interceptors[:]...) +} + +func (c *EntityClient) mutate(ctx context.Context, m *EntityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&EntityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&EntityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&EntityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&EntityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Entity mutation op: %q", m.Op()) + } +} + +// FileClient is a client for the File schema. +type FileClient struct { + config +} + +// NewFileClient returns a client for the File from the given config. +func NewFileClient(c config) *FileClient { + return &FileClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `file.Hooks(f(g(h())))`. +func (c *FileClient) Use(hooks ...Hook) { + c.hooks.File = append(c.hooks.File, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `file.Intercept(f(g(h())))`. +func (c *FileClient) Intercept(interceptors ...Interceptor) { + c.inters.File = append(c.inters.File, interceptors...) +} + +// Create returns a builder for creating a File entity. +func (c *FileClient) Create() *FileCreate { + mutation := newFileMutation(c.config, OpCreate) + return &FileCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of File entities. +func (c *FileClient) CreateBulk(builders ...*FileCreate) *FileCreateBulk { + return &FileCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *FileClient) MapCreateBulk(slice any, setFunc func(*FileCreate, int)) *FileCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &FileCreateBulk{err: fmt.Errorf("calling to FileClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*FileCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &FileCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for File. 
+func (c *FileClient) Update() *FileUpdate { + mutation := newFileMutation(c.config, OpUpdate) + return &FileUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *FileClient) UpdateOne(f *File) *FileUpdateOne { + mutation := newFileMutation(c.config, OpUpdateOne, withFile(f)) + return &FileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *FileClient) UpdateOneID(id int) *FileUpdateOne { + mutation := newFileMutation(c.config, OpUpdateOne, withFileID(id)) + return &FileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for File. +func (c *FileClient) Delete() *FileDelete { + mutation := newFileMutation(c.config, OpDelete) + return &FileDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *FileClient) DeleteOne(f *File) *FileDeleteOne { + return c.DeleteOneID(f.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *FileClient) DeleteOneID(id int) *FileDeleteOne { + builder := c.Delete().Where(file.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &FileDeleteOne{builder} +} + +// Query returns a query builder for File. +func (c *FileClient) Query() *FileQuery { + return &FileQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeFile}, + inters: c.Interceptors(), + } +} + +// Get returns a File entity by its id. +func (c *FileClient) Get(ctx context.Context, id int) (*File, error) { + return c.Query().Where(file.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *FileClient) GetX(ctx context.Context, id int) *File { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryOwner queries the owner edge of a File. +func (c *FileClient) QueryOwner(f *File) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.OwnerTable, file.OwnerColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryStoragePolicies queries the storage_policies edge of a File. +func (c *FileClient) QueryStoragePolicies(f *File) *StoragePolicyQuery { + query := (&StoragePolicyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.StoragePoliciesTable, file.StoragePoliciesColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryParent queries the parent edge of a File. 
+func (c *FileClient) QueryParent(f *File) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.ParentTable, file.ParentColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChildren queries the children edge of a File. +func (c *FileClient) QueryChildren(f *File) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.ChildrenTable, file.ChildrenColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryMetadata queries the metadata edge of a File. +func (c *FileClient) QueryMetadata(f *File) *MetadataQuery { + query := (&MetadataClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(metadata.Table, metadata.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.MetadataTable, file.MetadataColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryEntities queries the entities edge of a File. +func (c *FileClient) QueryEntities(f *File) *EntityQuery { + query := (&EntityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, file.EntitiesTable, file.EntitiesPrimaryKey...), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryShares queries the shares edge of a File. +func (c *FileClient) QueryShares(f *File) *ShareQuery { + query := (&ShareClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(share.Table, share.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.SharesTable, file.SharesColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryDirectLinks queries the direct_links edge of a File. +func (c *FileClient) QueryDirectLinks(f *File) *DirectLinkQuery { + query := (&DirectLinkClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := f.ID + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, id), + sqlgraph.To(directlink.Table, directlink.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.DirectLinksTable, file.DirectLinksColumn), + ) + fromV = sqlgraph.Neighbors(f.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *FileClient) Hooks() []Hook { + hooks := c.hooks.File + return append(hooks[:len(hooks):len(hooks)], file.Hooks[:]...) +} + +// Interceptors returns the client interceptors. 
+func (c *FileClient) Interceptors() []Interceptor { + inters := c.inters.File + return append(inters[:len(inters):len(inters)], file.Interceptors[:]...) +} + +func (c *FileClient) mutate(ctx context.Context, m *FileMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&FileCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&FileUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&FileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&FileDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown File mutation op: %q", m.Op()) + } +} + +// GroupClient is a client for the Group schema. +type GroupClient struct { + config +} + +// NewGroupClient returns a client for the Group from the given config. +func NewGroupClient(c config) *GroupClient { + return &GroupClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `group.Hooks(f(g(h())))`. +func (c *GroupClient) Use(hooks ...Hook) { + c.hooks.Group = append(c.hooks.Group, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `group.Intercept(f(g(h())))`. +func (c *GroupClient) Intercept(interceptors ...Interceptor) { + c.inters.Group = append(c.inters.Group, interceptors...) +} + +// Create returns a builder for creating a Group entity. +func (c *GroupClient) Create() *GroupCreate { + mutation := newGroupMutation(c.config, OpCreate) + return &GroupCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Group entities. +func (c *GroupClient) CreateBulk(builders ...*GroupCreate) *GroupCreateBulk { + return &GroupCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *GroupClient) MapCreateBulk(slice any, setFunc func(*GroupCreate, int)) *GroupCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &GroupCreateBulk{err: fmt.Errorf("calling to GroupClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*GroupCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &GroupCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Group. +func (c *GroupClient) Update() *GroupUpdate { + mutation := newGroupMutation(c.config, OpUpdate) + return &GroupUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *GroupClient) UpdateOne(gr *Group) *GroupUpdateOne { + mutation := newGroupMutation(c.config, OpUpdateOne, withGroup(gr)) + return &GroupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *GroupClient) UpdateOneID(id int) *GroupUpdateOne { + mutation := newGroupMutation(c.config, OpUpdateOne, withGroupID(id)) + return &GroupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Group. 
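+// Illustrative sketch, assuming `client`, `ctx` and a group id `id` are in scope:
+//
+//	affected, err := client.Group.Delete().
+//		Where(group.ID(id)).
+//		Exec(ctx)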
+func (c *GroupClient) Delete() *GroupDelete { + mutation := newGroupMutation(c.config, OpDelete) + return &GroupDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *GroupClient) DeleteOne(gr *Group) *GroupDeleteOne { + return c.DeleteOneID(gr.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *GroupClient) DeleteOneID(id int) *GroupDeleteOne { + builder := c.Delete().Where(group.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &GroupDeleteOne{builder} +} + +// Query returns a query builder for Group. +func (c *GroupClient) Query() *GroupQuery { + return &GroupQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeGroup}, + inters: c.Interceptors(), + } +} + +// Get returns a Group entity by its id. +func (c *GroupClient) Get(ctx context.Context, id int) (*Group, error) { + return c.Query().Where(group.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *GroupClient) GetX(ctx context.Context, id int) *Group { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUsers queries the users edge of a Group. +func (c *GroupClient) QueryUsers(gr *Group) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := gr.ID + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsersTable, group.UsersColumn), + ) + fromV = sqlgraph.Neighbors(gr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryStoragePolicies queries the storage_policies edge of a Group. +func (c *GroupClient) QueryStoragePolicies(gr *Group) *StoragePolicyQuery { + query := (&StoragePolicyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := gr.ID + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, id), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, group.StoragePoliciesTable, group.StoragePoliciesColumn), + ) + fromV = sqlgraph.Neighbors(gr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *GroupClient) Hooks() []Hook { + hooks := c.hooks.Group + return append(hooks[:len(hooks):len(hooks)], group.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *GroupClient) Interceptors() []Interceptor { + inters := c.inters.Group + return append(inters[:len(inters):len(inters)], group.Interceptors[:]...) +} + +func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&GroupCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&GroupUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&GroupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&GroupDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Group mutation op: %q", m.Op()) + } +} + +// MetadataClient is a client for the Metadata schema. +type MetadataClient struct { + config +} + +// NewMetadataClient returns a client for the Metadata from the given config. 
+func NewMetadataClient(c config) *MetadataClient { + return &MetadataClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `metadata.Hooks(f(g(h())))`. +func (c *MetadataClient) Use(hooks ...Hook) { + c.hooks.Metadata = append(c.hooks.Metadata, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `metadata.Intercept(f(g(h())))`. +func (c *MetadataClient) Intercept(interceptors ...Interceptor) { + c.inters.Metadata = append(c.inters.Metadata, interceptors...) +} + +// Create returns a builder for creating a Metadata entity. +func (c *MetadataClient) Create() *MetadataCreate { + mutation := newMetadataMutation(c.config, OpCreate) + return &MetadataCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Metadata entities. +func (c *MetadataClient) CreateBulk(builders ...*MetadataCreate) *MetadataCreateBulk { + return &MetadataCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MetadataClient) MapCreateBulk(slice any, setFunc func(*MetadataCreate, int)) *MetadataCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MetadataCreateBulk{err: fmt.Errorf("calling to MetadataClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MetadataCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MetadataCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Metadata. +func (c *MetadataClient) Update() *MetadataUpdate { + mutation := newMetadataMutation(c.config, OpUpdate) + return &MetadataUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *MetadataClient) UpdateOne(m *Metadata) *MetadataUpdateOne { + mutation := newMetadataMutation(c.config, OpUpdateOne, withMetadata(m)) + return &MetadataUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *MetadataClient) UpdateOneID(id int) *MetadataUpdateOne { + mutation := newMetadataMutation(c.config, OpUpdateOne, withMetadataID(id)) + return &MetadataUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Metadata. +func (c *MetadataClient) Delete() *MetadataDelete { + mutation := newMetadataMutation(c.config, OpDelete) + return &MetadataDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *MetadataClient) DeleteOne(m *Metadata) *MetadataDeleteOne { + return c.DeleteOneID(m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *MetadataClient) DeleteOneID(id int) *MetadataDeleteOne { + builder := c.Delete().Where(metadata.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &MetadataDeleteOne{builder} +} + +// Query returns a query builder for Metadata. 
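+// Illustrative sketch, assuming `client` and `ctx` are in scope:
+//
+//	all, err := client.Metadata.Query().All(ctx)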
+func (c *MetadataClient) Query() *MetadataQuery { + return &MetadataQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeMetadata}, + inters: c.Interceptors(), + } +} + +// Get returns a Metadata entity by its id. +func (c *MetadataClient) Get(ctx context.Context, id int) (*Metadata, error) { + return c.Query().Where(metadata.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *MetadataClient) GetX(ctx context.Context, id int) *Metadata { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryFile queries the file edge of a Metadata. +func (c *MetadataClient) QueryFile(m *Metadata) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := m.ID + step := sqlgraph.NewStep( + sqlgraph.From(metadata.Table, metadata.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, metadata.FileTable, metadata.FileColumn), + ) + fromV = sqlgraph.Neighbors(m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *MetadataClient) Hooks() []Hook { + hooks := c.hooks.Metadata + return append(hooks[:len(hooks):len(hooks)], metadata.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *MetadataClient) Interceptors() []Interceptor { + inters := c.inters.Metadata + return append(inters[:len(inters):len(inters)], metadata.Interceptors[:]...) +} + +func (c *MetadataClient) mutate(ctx context.Context, m *MetadataMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MetadataCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MetadataUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MetadataUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MetadataDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Metadata mutation op: %q", m.Op()) + } +} + +// NodeClient is a client for the Node schema. +type NodeClient struct { + config +} + +// NewNodeClient returns a client for the Node from the given config. +func NewNodeClient(c config) *NodeClient { + return &NodeClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `node.Hooks(f(g(h())))`. +func (c *NodeClient) Use(hooks ...Hook) { + c.hooks.Node = append(c.hooks.Node, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `node.Intercept(f(g(h())))`. +func (c *NodeClient) Intercept(interceptors ...Interceptor) { + c.inters.Node = append(c.inters.Node, interceptors...) +} + +// Create returns a builder for creating a Node entity. +func (c *NodeClient) Create() *NodeCreate { + mutation := newNodeMutation(c.config, OpCreate) + return &NodeCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Node entities. +func (c *NodeClient) CreateBulk(builders ...*NodeCreate) *NodeCreateBulk { + return &NodeCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
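+// Illustrative sketch; `names` and the SetName setter are assumptions made for
+// the example, as are `client` and `ctx`:
+//
+//	names := []string{"node-a", "node-b"}
+//	nodes, err := client.Node.MapCreateBulk(names, func(nc *NodeCreate, i int) {
+//		nc.SetName(names[i])
+//	}).Save(ctx)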
+func (c *NodeClient) MapCreateBulk(slice any, setFunc func(*NodeCreate, int)) *NodeCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &NodeCreateBulk{err: fmt.Errorf("calling to NodeClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*NodeCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &NodeCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Node. +func (c *NodeClient) Update() *NodeUpdate { + mutation := newNodeMutation(c.config, OpUpdate) + return &NodeUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *NodeClient) UpdateOne(n *Node) *NodeUpdateOne { + mutation := newNodeMutation(c.config, OpUpdateOne, withNode(n)) + return &NodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *NodeClient) UpdateOneID(id int) *NodeUpdateOne { + mutation := newNodeMutation(c.config, OpUpdateOne, withNodeID(id)) + return &NodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Node. +func (c *NodeClient) Delete() *NodeDelete { + mutation := newNodeMutation(c.config, OpDelete) + return &NodeDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *NodeClient) DeleteOne(n *Node) *NodeDeleteOne { + return c.DeleteOneID(n.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *NodeClient) DeleteOneID(id int) *NodeDeleteOne { + builder := c.Delete().Where(node.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &NodeDeleteOne{builder} +} + +// Query returns a query builder for Node. +func (c *NodeClient) Query() *NodeQuery { + return &NodeQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeNode}, + inters: c.Interceptors(), + } +} + +// Get returns a Node entity by its id. +func (c *NodeClient) Get(ctx context.Context, id int) (*Node, error) { + return c.Query().Where(node.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *NodeClient) GetX(ctx context.Context, id int) *Node { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryStoragePolicy queries the storage_policy edge of a Node. +func (c *NodeClient) QueryStoragePolicy(n *Node) *StoragePolicyQuery { + query := (&StoragePolicyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := n.ID + step := sqlgraph.NewStep( + sqlgraph.From(node.Table, node.FieldID, id), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, node.StoragePolicyTable, node.StoragePolicyColumn), + ) + fromV = sqlgraph.Neighbors(n.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *NodeClient) Hooks() []Hook { + hooks := c.hooks.Node + return append(hooks[:len(hooks):len(hooks)], node.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *NodeClient) Interceptors() []Interceptor { + inters := c.inters.Node + return append(inters[:len(inters):len(inters)], node.Interceptors[:]...) 
+} + +func (c *NodeClient) mutate(ctx context.Context, m *NodeMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&NodeCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&NodeUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&NodeUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&NodeDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Node mutation op: %q", m.Op()) + } +} + +// PasskeyClient is a client for the Passkey schema. +type PasskeyClient struct { + config +} + +// NewPasskeyClient returns a client for the Passkey from the given config. +func NewPasskeyClient(c config) *PasskeyClient { + return &PasskeyClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `passkey.Hooks(f(g(h())))`. +func (c *PasskeyClient) Use(hooks ...Hook) { + c.hooks.Passkey = append(c.hooks.Passkey, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `passkey.Intercept(f(g(h())))`. +func (c *PasskeyClient) Intercept(interceptors ...Interceptor) { + c.inters.Passkey = append(c.inters.Passkey, interceptors...) +} + +// Create returns a builder for creating a Passkey entity. +func (c *PasskeyClient) Create() *PasskeyCreate { + mutation := newPasskeyMutation(c.config, OpCreate) + return &PasskeyCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Passkey entities. +func (c *PasskeyClient) CreateBulk(builders ...*PasskeyCreate) *PasskeyCreateBulk { + return &PasskeyCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PasskeyClient) MapCreateBulk(slice any, setFunc func(*PasskeyCreate, int)) *PasskeyCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PasskeyCreateBulk{err: fmt.Errorf("calling to PasskeyClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PasskeyCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PasskeyCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Passkey. +func (c *PasskeyClient) Update() *PasskeyUpdate { + mutation := newPasskeyMutation(c.config, OpUpdate) + return &PasskeyUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PasskeyClient) UpdateOne(pa *Passkey) *PasskeyUpdateOne { + mutation := newPasskeyMutation(c.config, OpUpdateOne, withPasskey(pa)) + return &PasskeyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PasskeyClient) UpdateOneID(id int) *PasskeyUpdateOne { + mutation := newPasskeyMutation(c.config, OpUpdateOne, withPasskeyID(id)) + return &PasskeyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Passkey. 
+func (c *PasskeyClient) Delete() *PasskeyDelete { + mutation := newPasskeyMutation(c.config, OpDelete) + return &PasskeyDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PasskeyClient) DeleteOne(pa *Passkey) *PasskeyDeleteOne { + return c.DeleteOneID(pa.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PasskeyClient) DeleteOneID(id int) *PasskeyDeleteOne { + builder := c.Delete().Where(passkey.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PasskeyDeleteOne{builder} +} + +// Query returns a query builder for Passkey. +func (c *PasskeyClient) Query() *PasskeyQuery { + return &PasskeyQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePasskey}, + inters: c.Interceptors(), + } +} + +// Get returns a Passkey entity by its id. +func (c *PasskeyClient) Get(ctx context.Context, id int) (*Passkey, error) { + return c.Query().Where(passkey.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PasskeyClient) GetX(ctx context.Context, id int) *Passkey { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a Passkey. +func (c *PasskeyClient) QueryUser(pa *Passkey) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := pa.ID + step := sqlgraph.NewStep( + sqlgraph.From(passkey.Table, passkey.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, passkey.UserTable, passkey.UserColumn), + ) + fromV = sqlgraph.Neighbors(pa.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *PasskeyClient) Hooks() []Hook { + hooks := c.hooks.Passkey + return append(hooks[:len(hooks):len(hooks)], passkey.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *PasskeyClient) Interceptors() []Interceptor { + inters := c.inters.Passkey + return append(inters[:len(inters):len(inters)], passkey.Interceptors[:]...) +} + +func (c *PasskeyClient) mutate(ctx context.Context, m *PasskeyMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PasskeyCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PasskeyUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PasskeyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PasskeyDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Passkey mutation op: %q", m.Op()) + } +} + +// SettingClient is a client for the Setting schema. +type SettingClient struct { + config +} + +// NewSettingClient returns a client for the Setting from the given config. +func NewSettingClient(c config) *SettingClient { + return &SettingClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `setting.Hooks(f(g(h())))`. +func (c *SettingClient) Use(hooks ...Hook) { + c.hooks.Setting = append(c.hooks.Setting, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `setting.Intercept(f(g(h())))`. 
+func (c *SettingClient) Intercept(interceptors ...Interceptor) { + c.inters.Setting = append(c.inters.Setting, interceptors...) +} + +// Create returns a builder for creating a Setting entity. +func (c *SettingClient) Create() *SettingCreate { + mutation := newSettingMutation(c.config, OpCreate) + return &SettingCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Setting entities. +func (c *SettingClient) CreateBulk(builders ...*SettingCreate) *SettingCreateBulk { + return &SettingCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SettingClient) MapCreateBulk(slice any, setFunc func(*SettingCreate, int)) *SettingCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SettingCreateBulk{err: fmt.Errorf("calling to SettingClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SettingCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SettingCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Setting. +func (c *SettingClient) Update() *SettingUpdate { + mutation := newSettingMutation(c.config, OpUpdate) + return &SettingUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SettingClient) UpdateOne(s *Setting) *SettingUpdateOne { + mutation := newSettingMutation(c.config, OpUpdateOne, withSetting(s)) + return &SettingUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *SettingClient) UpdateOneID(id int) *SettingUpdateOne { + mutation := newSettingMutation(c.config, OpUpdateOne, withSettingID(id)) + return &SettingUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Setting. +func (c *SettingClient) Delete() *SettingDelete { + mutation := newSettingMutation(c.config, OpDelete) + return &SettingDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SettingClient) DeleteOne(s *Setting) *SettingDeleteOne { + return c.DeleteOneID(s.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SettingClient) DeleteOneID(id int) *SettingDeleteOne { + builder := c.Delete().Where(setting.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SettingDeleteOne{builder} +} + +// Query returns a query builder for Setting. +func (c *SettingClient) Query() *SettingQuery { + return &SettingQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSetting}, + inters: c.Interceptors(), + } +} + +// Get returns a Setting entity by its id. +func (c *SettingClient) Get(ctx context.Context, id int) (*Setting, error) { + return c.Query().Where(setting.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SettingClient) GetX(ctx context.Context, id int) *Setting { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. 
+func (c *SettingClient) Hooks() []Hook { + hooks := c.hooks.Setting + return append(hooks[:len(hooks):len(hooks)], setting.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *SettingClient) Interceptors() []Interceptor { + inters := c.inters.Setting + return append(inters[:len(inters):len(inters)], setting.Interceptors[:]...) +} + +func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SettingCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SettingUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SettingUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SettingDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Setting mutation op: %q", m.Op()) + } +} + +// ShareClient is a client for the Share schema. +type ShareClient struct { + config +} + +// NewShareClient returns a client for the Share from the given config. +func NewShareClient(c config) *ShareClient { + return &ShareClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `share.Hooks(f(g(h())))`. +func (c *ShareClient) Use(hooks ...Hook) { + c.hooks.Share = append(c.hooks.Share, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `share.Intercept(f(g(h())))`. +func (c *ShareClient) Intercept(interceptors ...Interceptor) { + c.inters.Share = append(c.inters.Share, interceptors...) +} + +// Create returns a builder for creating a Share entity. +func (c *ShareClient) Create() *ShareCreate { + mutation := newShareMutation(c.config, OpCreate) + return &ShareCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Share entities. +func (c *ShareClient) CreateBulk(builders ...*ShareCreate) *ShareCreateBulk { + return &ShareCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ShareClient) MapCreateBulk(slice any, setFunc func(*ShareCreate, int)) *ShareCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ShareCreateBulk{err: fmt.Errorf("calling to ShareClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ShareCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ShareCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Share. +func (c *ShareClient) Update() *ShareUpdate { + mutation := newShareMutation(c.config, OpUpdate) + return &ShareUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ShareClient) UpdateOne(s *Share) *ShareUpdateOne { + mutation := newShareMutation(c.config, OpUpdateOne, withShare(s)) + return &ShareUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *ShareClient) UpdateOneID(id int) *ShareUpdateOne { + mutation := newShareMutation(c.config, OpUpdateOne, withShareID(id)) + return &ShareUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Share. +func (c *ShareClient) Delete() *ShareDelete { + mutation := newShareMutation(c.config, OpDelete) + return &ShareDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ShareClient) DeleteOne(s *Share) *ShareDeleteOne { + return c.DeleteOneID(s.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ShareClient) DeleteOneID(id int) *ShareDeleteOne { + builder := c.Delete().Where(share.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ShareDeleteOne{builder} +} + +// Query returns a query builder for Share. +func (c *ShareClient) Query() *ShareQuery { + return &ShareQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeShare}, + inters: c.Interceptors(), + } +} + +// Get returns a Share entity by its id. +func (c *ShareClient) Get(ctx context.Context, id int) (*Share, error) { + return c.Query().Where(share.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ShareClient) GetX(ctx context.Context, id int) *Share { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a Share. +func (c *ShareClient) QueryUser(s *Share) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := s.ID + step := sqlgraph.NewStep( + sqlgraph.From(share.Table, share.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, share.UserTable, share.UserColumn), + ) + fromV = sqlgraph.Neighbors(s.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryFile queries the file edge of a Share. +func (c *ShareClient) QueryFile(s *Share) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := s.ID + step := sqlgraph.NewStep( + sqlgraph.From(share.Table, share.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, share.FileTable, share.FileColumn), + ) + fromV = sqlgraph.Neighbors(s.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *ShareClient) Hooks() []Hook { + hooks := c.hooks.Share + return append(hooks[:len(hooks):len(hooks)], share.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *ShareClient) Interceptors() []Interceptor { + inters := c.inters.Share + return append(inters[:len(inters):len(inters)], share.Interceptors[:]...) 
+} + +func (c *ShareClient) mutate(ctx context.Context, m *ShareMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ShareCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ShareUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ShareUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ShareDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Share mutation op: %q", m.Op()) + } +} + +// StoragePolicyClient is a client for the StoragePolicy schema. +type StoragePolicyClient struct { + config +} + +// NewStoragePolicyClient returns a client for the StoragePolicy from the given config. +func NewStoragePolicyClient(c config) *StoragePolicyClient { + return &StoragePolicyClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `storagepolicy.Hooks(f(g(h())))`. +func (c *StoragePolicyClient) Use(hooks ...Hook) { + c.hooks.StoragePolicy = append(c.hooks.StoragePolicy, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `storagepolicy.Intercept(f(g(h())))`. +func (c *StoragePolicyClient) Intercept(interceptors ...Interceptor) { + c.inters.StoragePolicy = append(c.inters.StoragePolicy, interceptors...) +} + +// Create returns a builder for creating a StoragePolicy entity. +func (c *StoragePolicyClient) Create() *StoragePolicyCreate { + mutation := newStoragePolicyMutation(c.config, OpCreate) + return &StoragePolicyCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of StoragePolicy entities. +func (c *StoragePolicyClient) CreateBulk(builders ...*StoragePolicyCreate) *StoragePolicyCreateBulk { + return &StoragePolicyCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *StoragePolicyClient) MapCreateBulk(slice any, setFunc func(*StoragePolicyCreate, int)) *StoragePolicyCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &StoragePolicyCreateBulk{err: fmt.Errorf("calling to StoragePolicyClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*StoragePolicyCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &StoragePolicyCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for StoragePolicy. +func (c *StoragePolicyClient) Update() *StoragePolicyUpdate { + mutation := newStoragePolicyMutation(c.config, OpUpdate) + return &StoragePolicyUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *StoragePolicyClient) UpdateOne(sp *StoragePolicy) *StoragePolicyUpdateOne { + mutation := newStoragePolicyMutation(c.config, OpUpdateOne, withStoragePolicy(sp)) + return &StoragePolicyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *StoragePolicyClient) UpdateOneID(id int) *StoragePolicyUpdateOne { + mutation := newStoragePolicyMutation(c.config, OpUpdateOne, withStoragePolicyID(id)) + return &StoragePolicyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for StoragePolicy. +func (c *StoragePolicyClient) Delete() *StoragePolicyDelete { + mutation := newStoragePolicyMutation(c.config, OpDelete) + return &StoragePolicyDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *StoragePolicyClient) DeleteOne(sp *StoragePolicy) *StoragePolicyDeleteOne { + return c.DeleteOneID(sp.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *StoragePolicyClient) DeleteOneID(id int) *StoragePolicyDeleteOne { + builder := c.Delete().Where(storagepolicy.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &StoragePolicyDeleteOne{builder} +} + +// Query returns a query builder for StoragePolicy. +func (c *StoragePolicyClient) Query() *StoragePolicyQuery { + return &StoragePolicyQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeStoragePolicy}, + inters: c.Interceptors(), + } +} + +// Get returns a StoragePolicy entity by its id. +func (c *StoragePolicyClient) Get(ctx context.Context, id int) (*StoragePolicy, error) { + return c.Query().Where(storagepolicy.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *StoragePolicyClient) GetX(ctx context.Context, id int) *StoragePolicy { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUsers queries the users edge of a StoragePolicy. +func (c *StoragePolicyClient) QueryUsers(sp *StoragePolicy) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := sp.ID + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.UsersTable, storagepolicy.UsersColumn), + ) + fromV = sqlgraph.Neighbors(sp.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryGroups queries the groups edge of a StoragePolicy. +func (c *StoragePolicyClient) QueryGroups(sp *StoragePolicy) *GroupQuery { + query := (&GroupClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := sp.ID + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, id), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.GroupsTable, storagepolicy.GroupsColumn), + ) + fromV = sqlgraph.Neighbors(sp.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryFiles queries the files edge of a StoragePolicy. +func (c *StoragePolicyClient) QueryFiles(sp *StoragePolicy) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := sp.ID + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.FilesTable, storagepolicy.FilesColumn), + ) + fromV = sqlgraph.Neighbors(sp.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryEntities queries the entities edge of a StoragePolicy. 
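+// Illustrative sketch, assuming `client`, `ctx` and an existing *StoragePolicy
+// value `sp` are in scope:
+//
+//	entities, err := client.StoragePolicy.QueryEntities(sp).All(ctx)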
+func (c *StoragePolicyClient) QueryEntities(sp *StoragePolicy) *EntityQuery { + query := (&EntityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := sp.ID + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, id), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.EntitiesTable, storagepolicy.EntitiesColumn), + ) + fromV = sqlgraph.Neighbors(sp.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryNode queries the node edge of a StoragePolicy. +func (c *StoragePolicyClient) QueryNode(sp *StoragePolicy) *NodeQuery { + query := (&NodeClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := sp.ID + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, id), + sqlgraph.To(node.Table, node.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, storagepolicy.NodeTable, storagepolicy.NodeColumn), + ) + fromV = sqlgraph.Neighbors(sp.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *StoragePolicyClient) Hooks() []Hook { + hooks := c.hooks.StoragePolicy + return append(hooks[:len(hooks):len(hooks)], storagepolicy.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *StoragePolicyClient) Interceptors() []Interceptor { + inters := c.inters.StoragePolicy + return append(inters[:len(inters):len(inters)], storagepolicy.Interceptors[:]...) +} + +func (c *StoragePolicyClient) mutate(ctx context.Context, m *StoragePolicyMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&StoragePolicyCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&StoragePolicyUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&StoragePolicyUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&StoragePolicyDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown StoragePolicy mutation op: %q", m.Op()) + } +} + +// TaskClient is a client for the Task schema. +type TaskClient struct { + config +} + +// NewTaskClient returns a client for the Task from the given config. +func NewTaskClient(c config) *TaskClient { + return &TaskClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `task.Hooks(f(g(h())))`. +func (c *TaskClient) Use(hooks ...Hook) { + c.hooks.Task = append(c.hooks.Task, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `task.Intercept(f(g(h())))`. +func (c *TaskClient) Intercept(interceptors ...Interceptor) { + c.inters.Task = append(c.inters.Task, interceptors...) +} + +// Create returns a builder for creating a Task entity. +func (c *TaskClient) Create() *TaskCreate { + mutation := newTaskMutation(c.config, OpCreate) + return &TaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Task entities. +func (c *TaskClient) CreateBulk(builders ...*TaskCreate) *TaskCreateBulk { + return &TaskCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. 
For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *TaskClient) MapCreateBulk(slice any, setFunc func(*TaskCreate, int)) *TaskCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &TaskCreateBulk{err: fmt.Errorf("calling to TaskClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*TaskCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &TaskCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Task. +func (c *TaskClient) Update() *TaskUpdate { + mutation := newTaskMutation(c.config, OpUpdate) + return &TaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *TaskClient) UpdateOne(t *Task) *TaskUpdateOne { + mutation := newTaskMutation(c.config, OpUpdateOne, withTask(t)) + return &TaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *TaskClient) UpdateOneID(id int) *TaskUpdateOne { + mutation := newTaskMutation(c.config, OpUpdateOne, withTaskID(id)) + return &TaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Task. +func (c *TaskClient) Delete() *TaskDelete { + mutation := newTaskMutation(c.config, OpDelete) + return &TaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *TaskClient) DeleteOne(t *Task) *TaskDeleteOne { + return c.DeleteOneID(t.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *TaskClient) DeleteOneID(id int) *TaskDeleteOne { + builder := c.Delete().Where(task.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &TaskDeleteOne{builder} +} + +// Query returns a query builder for Task. +func (c *TaskClient) Query() *TaskQuery { + return &TaskQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeTask}, + inters: c.Interceptors(), + } +} + +// Get returns a Task entity by its id. +func (c *TaskClient) Get(ctx context.Context, id int) (*Task, error) { + return c.Query().Where(task.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *TaskClient) GetX(ctx context.Context, id int) *Task { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a Task. +func (c *TaskClient) QueryUser(t *Task) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := t.ID + step := sqlgraph.NewStep( + sqlgraph.From(task.Table, task.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, task.UserTable, task.UserColumn), + ) + fromV = sqlgraph.Neighbors(t.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *TaskClient) Hooks() []Hook { + hooks := c.hooks.Task + return append(hooks[:len(hooks):len(hooks)], task.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *TaskClient) Interceptors() []Interceptor { + inters := c.inters.Task + return append(inters[:len(inters):len(inters)], task.Interceptors[:]...) 
+} + +func (c *TaskClient) mutate(ctx context.Context, m *TaskMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&TaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&TaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&TaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&TaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Task mutation op: %q", m.Op()) + } +} + +// UserClient is a client for the User schema. +type UserClient struct { + config +} + +// NewUserClient returns a client for the User from the given config. +func NewUserClient(c config) *UserClient { + return &UserClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `user.Hooks(f(g(h())))`. +func (c *UserClient) Use(hooks ...Hook) { + c.hooks.User = append(c.hooks.User, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `user.Intercept(f(g(h())))`. +func (c *UserClient) Intercept(interceptors ...Interceptor) { + c.inters.User = append(c.inters.User, interceptors...) +} + +// Create returns a builder for creating a User entity. +func (c *UserClient) Create() *UserCreate { + mutation := newUserMutation(c.config, OpCreate) + return &UserCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of User entities. +func (c *UserClient) CreateBulk(builders ...*UserCreate) *UserCreateBulk { + return &UserCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UserClient) MapCreateBulk(slice any, setFunc func(*UserCreate, int)) *UserCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UserCreateBulk{err: fmt.Errorf("calling to UserClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UserCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UserCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for User. +func (c *UserClient) Update() *UserUpdate { + mutation := newUserMutation(c.config, OpUpdate) + return &UserUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UserClient) UpdateOne(u *User) *UserUpdateOne { + mutation := newUserMutation(c.config, OpUpdateOne, withUser(u)) + return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UserClient) UpdateOneID(id int) *UserUpdateOne { + mutation := newUserMutation(c.config, OpUpdateOne, withUserID(id)) + return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for User. +func (c *UserClient) Delete() *UserDelete { + mutation := newUserMutation(c.config, OpDelete) + return &UserDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. 
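+// Illustrative sketch, assuming `client`, `ctx` and an existing *User value `u`:
+//
+//	if err := client.User.DeleteOne(u).Exec(ctx); err != nil {
+//		// handle the error
+//	}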
+func (c *UserClient) DeleteOne(u *User) *UserDeleteOne { + return c.DeleteOneID(u.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { + builder := c.Delete().Where(user.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UserDeleteOne{builder} +} + +// Query returns a query builder for User. +func (c *UserClient) Query() *UserQuery { + return &UserQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUser}, + inters: c.Interceptors(), + } +} + +// Get returns a User entity by its id. +func (c *UserClient) Get(ctx context.Context, id int) (*User, error) { + return c.Query().Where(user.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UserClient) GetX(ctx context.Context, id int) *User { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryGroup queries the group edge of a User. +func (c *UserClient) QueryGroup(u *User) *GroupQuery { + query := (&GroupClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, user.GroupTable, user.GroupColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryFiles queries the files edge of a User. +func (c *UserClient) QueryFiles(u *User) *FileQuery { + query := (&FileClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.FilesTable, user.FilesColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryDavAccounts queries the dav_accounts edge of a User. +func (c *UserClient) QueryDavAccounts(u *User) *DavAccountQuery { + query := (&DavAccountClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(davaccount.Table, davaccount.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.DavAccountsTable, user.DavAccountsColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryShares queries the shares edge of a User. +func (c *UserClient) QueryShares(u *User) *ShareQuery { + query := (&ShareClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(share.Table, share.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.SharesTable, user.SharesColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPasskey queries the passkey edge of a User. 
+func (c *UserClient) QueryPasskey(u *User) *PasskeyQuery { + query := (&PasskeyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(passkey.Table, passkey.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PasskeyTable, user.PasskeyColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryTasks queries the tasks edge of a User. +func (c *UserClient) QueryTasks(u *User) *TaskQuery { + query := (&TaskClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(task.Table, task.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.TasksTable, user.TasksColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryEntities queries the entities edge of a User. +func (c *UserClient) QueryEntities(u *User) *EntityQuery { + query := (&EntityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.EntitiesTable, user.EntitiesColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *UserClient) Hooks() []Hook { + hooks := c.hooks.User + return append(hooks[:len(hooks):len(hooks)], user.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *UserClient) Interceptors() []Interceptor { + inters := c.inters.User + return append(inters[:len(inters):len(inters)], user.Interceptors[:]...) +} + +func (c *UserClient) mutate(ctx context.Context, m *UserMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UserCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UserUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UserDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown User mutation op: %q", m.Op()) + } +} + +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + DavAccount, DirectLink, Entity, File, Group, Metadata, Node, Passkey, Setting, + Share, StoragePolicy, Task, User []ent.Hook + } + inters struct { + DavAccount, DirectLink, Entity, File, Group, Metadata, Node, Passkey, Setting, + Share, StoragePolicy, Task, User []ent.Interceptor + } +) + +// ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it. +// See, database/sql#DB.ExecContext for more information. +func (c *config) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { + ex, ok := c.driver.(interface { + ExecContext(context.Context, string, ...any) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) +} + +// QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it. 
+// See, database/sql#DB.QueryContext for more information. +func (c *config) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { + q, ok := c.driver.(interface { + QueryContext(context.Context, string, ...any) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) +} diff --git a/ent/davaccount.go b/ent/davaccount.go new file mode 100644 index 00000000..294b4bcc --- /dev/null +++ b/ent/davaccount.go @@ -0,0 +1,242 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// DavAccount is the model entity for the DavAccount schema. +type DavAccount struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // URI holds the value of the "uri" field. + URI string `json:"uri,omitempty"` + // Password holds the value of the "password" field. + Password string `json:"-"` + // Options holds the value of the "options" field. + Options *boolset.BooleanSet `json:"options,omitempty"` + // Props holds the value of the "props" field. + Props *types.DavAccountProps `json:"props,omitempty"` + // OwnerID holds the value of the "owner_id" field. + OwnerID int `json:"owner_id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the DavAccountQuery when eager-loading is set. + Edges DavAccountEdges `json:"edges"` + selectValues sql.SelectValues +} + +// DavAccountEdges holds the relations/edges for other nodes in the graph. +type DavAccountEdges struct { + // Owner holds the value of the owner edge. + Owner *User `json:"owner,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// OwnerOrErr returns the Owner value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e DavAccountEdges) OwnerOrErr() (*User, error) { + if e.loadedTypes[0] { + if e.Owner == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.Owner, nil + } + return nil, &NotLoadedError{edge: "owner"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*DavAccount) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case davaccount.FieldProps: + values[i] = new([]byte) + case davaccount.FieldOptions: + values[i] = new(boolset.BooleanSet) + case davaccount.FieldID, davaccount.FieldOwnerID: + values[i] = new(sql.NullInt64) + case davaccount.FieldName, davaccount.FieldURI, davaccount.FieldPassword: + values[i] = new(sql.NullString) + case davaccount.FieldCreatedAt, davaccount.FieldUpdatedAt, davaccount.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the DavAccount fields. +func (da *DavAccount) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case davaccount.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + da.ID = int(value.Int64) + case davaccount.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + da.CreatedAt = value.Time + } + case davaccount.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + da.UpdatedAt = value.Time + } + case davaccount.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + da.DeletedAt = new(time.Time) + *da.DeletedAt = value.Time + } + case davaccount.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + da.Name = value.String + } + case davaccount.FieldURI: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field uri", values[i]) + } else if value.Valid { + da.URI = value.String + } + case davaccount.FieldPassword: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field password", values[i]) + } else if value.Valid { + da.Password = value.String + } + case davaccount.FieldOptions: + if value, ok := values[i].(*boolset.BooleanSet); !ok { + return fmt.Errorf("unexpected type %T for field options", values[i]) + } else if value != nil { + da.Options = value + } + case davaccount.FieldProps: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field props", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &da.Props); err != nil { + return fmt.Errorf("unmarshal field props: %w", err) + } + } + case davaccount.FieldOwnerID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field owner_id", values[i]) + } else if value.Valid { + da.OwnerID = int(value.Int64) + } + default: + da.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the DavAccount. +// This includes values selected through modifiers, order, etc. 
+func (da *DavAccount) Value(name string) (ent.Value, error) { + return da.selectValues.Get(name) +} + +// QueryOwner queries the "owner" edge of the DavAccount entity. +func (da *DavAccount) QueryOwner() *UserQuery { + return NewDavAccountClient(da.config).QueryOwner(da) +} + +// Update returns a builder for updating this DavAccount. +// Note that you need to call DavAccount.Unwrap() before calling this method if this DavAccount +// was returned from a transaction, and the transaction was committed or rolled back. +func (da *DavAccount) Update() *DavAccountUpdateOne { + return NewDavAccountClient(da.config).UpdateOne(da) +} + +// Unwrap unwraps the DavAccount entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (da *DavAccount) Unwrap() *DavAccount { + _tx, ok := da.config.driver.(*txDriver) + if !ok { + panic("ent: DavAccount is not a transactional entity") + } + da.config.driver = _tx.drv + return da +} + +// String implements the fmt.Stringer. +func (da *DavAccount) String() string { + var builder strings.Builder + builder.WriteString("DavAccount(") + builder.WriteString(fmt.Sprintf("id=%v, ", da.ID)) + builder.WriteString("created_at=") + builder.WriteString(da.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(da.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := da.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(da.Name) + builder.WriteString(", ") + builder.WriteString("uri=") + builder.WriteString(da.URI) + builder.WriteString(", ") + builder.WriteString("password=") + builder.WriteString(", ") + builder.WriteString("options=") + builder.WriteString(fmt.Sprintf("%v", da.Options)) + builder.WriteString(", ") + builder.WriteString("props=") + builder.WriteString(fmt.Sprintf("%v", da.Props)) + builder.WriteString(", ") + builder.WriteString("owner_id=") + builder.WriteString(fmt.Sprintf("%v", da.OwnerID)) + builder.WriteByte(')') + return builder.String() +} + +// SetOwner manually set the edge as loaded state. +func (e *DavAccount) SetOwner(v *User) { + e.Edges.Owner = v + e.Edges.loadedTypes[0] = true +} + +// DavAccounts is a parsable slice of DavAccount. +type DavAccounts []*DavAccount diff --git a/ent/davaccount/davaccount.go b/ent/davaccount/davaccount.go new file mode 100644 index 00000000..169b331a --- /dev/null +++ b/ent/davaccount/davaccount.go @@ -0,0 +1,144 @@ +// Code generated by ent, DO NOT EDIT. + +package davaccount + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the davaccount type in the database. + Label = "dav_account" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldURI holds the string denoting the uri field in the database. 
+ FieldURI = "uri" + // FieldPassword holds the string denoting the password field in the database. + FieldPassword = "password" + // FieldOptions holds the string denoting the options field in the database. + FieldOptions = "options" + // FieldProps holds the string denoting the props field in the database. + FieldProps = "props" + // FieldOwnerID holds the string denoting the owner_id field in the database. + FieldOwnerID = "owner_id" + // EdgeOwner holds the string denoting the owner edge name in mutations. + EdgeOwner = "owner" + // Table holds the table name of the davaccount in the database. + Table = "dav_accounts" + // OwnerTable is the table that holds the owner relation/edge. + OwnerTable = "dav_accounts" + // OwnerInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + OwnerInverseTable = "users" + // OwnerColumn is the table column denoting the owner relation/edge. + OwnerColumn = "owner_id" +) + +// Columns holds all SQL columns for davaccount fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldURI, + FieldPassword, + FieldOptions, + FieldProps, + FieldOwnerID, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the DavAccount queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByURI orders the results by the uri field. +func ByURI(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldURI, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByOwnerID orders the results by the owner_id field. 
+func ByOwnerID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOwnerID, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} diff --git a/ent/davaccount/where.go b/ent/davaccount/where.go new file mode 100644 index 00000000..90b82f14 --- /dev/null +++ b/ent/davaccount/where.go @@ -0,0 +1,530 @@ +// Code generated by ent, DO NOT EDIT. + +package davaccount + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldName, v)) +} + +// URI applies equality check predicate on the "uri" field. It's identical to URIEQ. 
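The ByXxx order options pair with the predicate helpers defined in where.go: callers filter with package-level predicates and sort with OrderOption values. A hedged usage sketch, assuming the generated DavAccountQuery.Order method and the standard sql.OrderDesc term from entgo.io/ent/dialect/sql:

package example // hypothetical caller package

import (
	"context"

	"entgo.io/ent/dialect/sql"
	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
)

// recentAccountsNamed lists DAV accounts with the given name that are not
// soft-deleted, newest first.
func recentAccountsNamed(ctx context.Context, client *ent.Client, name string) ([]*ent.DavAccount, error) {
	return client.DavAccount.Query().
		Where(
			davaccount.Name(name),
			davaccount.DeletedAtIsNil(),
		).
		Order(davaccount.ByCreatedAt(sql.OrderDesc())).
		All(ctx)
}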
+func URI(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldURI, v)) +} + +// Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. +func Password(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldPassword, v)) +} + +// Options applies equality check predicate on the "options" field. It's identical to OptionsEQ. +func Options(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldOptions, v)) +} + +// OwnerID applies equality check predicate on the "owner_id" field. It's identical to OwnerIDEQ. +func OwnerID(v int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldOwnerID, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. 
+func UpdatedAtLT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.DavAccount { + return predicate.DavAccount(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. 
+func NameLTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContainsFold(FieldName, v)) +} + +// URIEQ applies the EQ predicate on the "uri" field. +func URIEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldURI, v)) +} + +// URINEQ applies the NEQ predicate on the "uri" field. +func URINEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldURI, v)) +} + +// URIIn applies the In predicate on the "uri" field. +func URIIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldURI, vs...)) +} + +// URINotIn applies the NotIn predicate on the "uri" field. +func URINotIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldURI, vs...)) +} + +// URIGT applies the GT predicate on the "uri" field. +func URIGT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldURI, v)) +} + +// URIGTE applies the GTE predicate on the "uri" field. +func URIGTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldURI, v)) +} + +// URILT applies the LT predicate on the "uri" field. +func URILT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldURI, v)) +} + +// URILTE applies the LTE predicate on the "uri" field. +func URILTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldURI, v)) +} + +// URIContains applies the Contains predicate on the "uri" field. +func URIContains(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContains(FieldURI, v)) +} + +// URIHasPrefix applies the HasPrefix predicate on the "uri" field. +func URIHasPrefix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasPrefix(FieldURI, v)) +} + +// URIHasSuffix applies the HasSuffix predicate on the "uri" field. +func URIHasSuffix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasSuffix(FieldURI, v)) +} + +// URIEqualFold applies the EqualFold predicate on the "uri" field. +func URIEqualFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEqualFold(FieldURI, v)) +} + +// URIContainsFold applies the ContainsFold predicate on the "uri" field. +func URIContainsFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContainsFold(FieldURI, v)) +} + +// PasswordEQ applies the EQ predicate on the "password" field. 
+func PasswordEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldPassword, v)) +} + +// PasswordNEQ applies the NEQ predicate on the "password" field. +func PasswordNEQ(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldPassword, v)) +} + +// PasswordIn applies the In predicate on the "password" field. +func PasswordIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldPassword, vs...)) +} + +// PasswordNotIn applies the NotIn predicate on the "password" field. +func PasswordNotIn(vs ...string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldPassword, vs...)) +} + +// PasswordGT applies the GT predicate on the "password" field. +func PasswordGT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldPassword, v)) +} + +// PasswordGTE applies the GTE predicate on the "password" field. +func PasswordGTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldPassword, v)) +} + +// PasswordLT applies the LT predicate on the "password" field. +func PasswordLT(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldPassword, v)) +} + +// PasswordLTE applies the LTE predicate on the "password" field. +func PasswordLTE(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldPassword, v)) +} + +// PasswordContains applies the Contains predicate on the "password" field. +func PasswordContains(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContains(FieldPassword, v)) +} + +// PasswordHasPrefix applies the HasPrefix predicate on the "password" field. +func PasswordHasPrefix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasPrefix(FieldPassword, v)) +} + +// PasswordHasSuffix applies the HasSuffix predicate on the "password" field. +func PasswordHasSuffix(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldHasSuffix(FieldPassword, v)) +} + +// PasswordEqualFold applies the EqualFold predicate on the "password" field. +func PasswordEqualFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEqualFold(FieldPassword, v)) +} + +// PasswordContainsFold applies the ContainsFold predicate on the "password" field. +func PasswordContainsFold(v string) predicate.DavAccount { + return predicate.DavAccount(sql.FieldContainsFold(FieldPassword, v)) +} + +// OptionsEQ applies the EQ predicate on the "options" field. +func OptionsEQ(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldOptions, v)) +} + +// OptionsNEQ applies the NEQ predicate on the "options" field. +func OptionsNEQ(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldOptions, v)) +} + +// OptionsIn applies the In predicate on the "options" field. +func OptionsIn(vs ...*boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldOptions, vs...)) +} + +// OptionsNotIn applies the NotIn predicate on the "options" field. +func OptionsNotIn(vs ...*boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldOptions, vs...)) +} + +// OptionsGT applies the GT predicate on the "options" field. +func OptionsGT(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGT(FieldOptions, v)) +} + +// OptionsGTE applies the GTE predicate on the "options" field. 
+func OptionsGTE(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldGTE(FieldOptions, v)) +} + +// OptionsLT applies the LT predicate on the "options" field. +func OptionsLT(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLT(FieldOptions, v)) +} + +// OptionsLTE applies the LTE predicate on the "options" field. +func OptionsLTE(v *boolset.BooleanSet) predicate.DavAccount { + return predicate.DavAccount(sql.FieldLTE(FieldOptions, v)) +} + +// PropsIsNil applies the IsNil predicate on the "props" field. +func PropsIsNil() predicate.DavAccount { + return predicate.DavAccount(sql.FieldIsNull(FieldProps)) +} + +// PropsNotNil applies the NotNil predicate on the "props" field. +func PropsNotNil() predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotNull(FieldProps)) +} + +// OwnerIDEQ applies the EQ predicate on the "owner_id" field. +func OwnerIDEQ(v int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldEQ(FieldOwnerID, v)) +} + +// OwnerIDNEQ applies the NEQ predicate on the "owner_id" field. +func OwnerIDNEQ(v int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNEQ(FieldOwnerID, v)) +} + +// OwnerIDIn applies the In predicate on the "owner_id" field. +func OwnerIDIn(vs ...int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldIn(FieldOwnerID, vs...)) +} + +// OwnerIDNotIn applies the NotIn predicate on the "owner_id" field. +func OwnerIDNotIn(vs ...int) predicate.DavAccount { + return predicate.DavAccount(sql.FieldNotIn(FieldOwnerID, vs...)) +} + +// HasOwner applies the HasEdge predicate on the "owner" edge. +func HasOwner() predicate.DavAccount { + return predicate.DavAccount(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). +func HasOwnerWith(preds ...predicate.User) predicate.DavAccount { + return predicate.DavAccount(func(s *sql.Selector) { + step := newOwnerStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.DavAccount) predicate.DavAccount { + return predicate.DavAccount(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.DavAccount) predicate.DavAccount { + return predicate.DavAccount(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.DavAccount) predicate.DavAccount { + return predicate.DavAccount(sql.NotPredicates(p)) +} diff --git a/ent/davaccount_create.go b/ent/davaccount_create.go new file mode 100644 index 00000000..3723fdf5 --- /dev/null +++ b/ent/davaccount_create.go @@ -0,0 +1,968 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// DavAccountCreate is the builder for creating a DavAccount entity. 
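These predicates compose: And/Or/Not group conditions, and HasOwnerWith pushes conditions across the owner edge without loading the user rows themselves. A sketch under those assumptions; the prefix argument is an arbitrary caller-supplied value, not a Cloudreve convention:

package example // hypothetical caller package

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
	"github.com/cloudreve/Cloudreve/v4/ent/user"
)

// accountsOfUsers filters DAV accounts across the owner edge and OR-combines
// two string predicates on different fields.
func accountsOfUsers(ctx context.Context, client *ent.Client, ownerIDs []int, prefix string) ([]*ent.DavAccount, error) {
	return client.DavAccount.Query().
		Where(
			davaccount.HasOwnerWith(user.IDIn(ownerIDs...)),
			davaccount.Or(
				davaccount.URIHasPrefix(prefix),
				davaccount.NameContainsFold(prefix),
			),
		).
		All(ctx)
}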
+type DavAccountCreate struct { + config + mutation *DavAccountMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (dac *DavAccountCreate) SetCreatedAt(t time.Time) *DavAccountCreate { + dac.mutation.SetCreatedAt(t) + return dac +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (dac *DavAccountCreate) SetNillableCreatedAt(t *time.Time) *DavAccountCreate { + if t != nil { + dac.SetCreatedAt(*t) + } + return dac +} + +// SetUpdatedAt sets the "updated_at" field. +func (dac *DavAccountCreate) SetUpdatedAt(t time.Time) *DavAccountCreate { + dac.mutation.SetUpdatedAt(t) + return dac +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (dac *DavAccountCreate) SetNillableUpdatedAt(t *time.Time) *DavAccountCreate { + if t != nil { + dac.SetUpdatedAt(*t) + } + return dac +} + +// SetDeletedAt sets the "deleted_at" field. +func (dac *DavAccountCreate) SetDeletedAt(t time.Time) *DavAccountCreate { + dac.mutation.SetDeletedAt(t) + return dac +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (dac *DavAccountCreate) SetNillableDeletedAt(t *time.Time) *DavAccountCreate { + if t != nil { + dac.SetDeletedAt(*t) + } + return dac +} + +// SetName sets the "name" field. +func (dac *DavAccountCreate) SetName(s string) *DavAccountCreate { + dac.mutation.SetName(s) + return dac +} + +// SetURI sets the "uri" field. +func (dac *DavAccountCreate) SetURI(s string) *DavAccountCreate { + dac.mutation.SetURI(s) + return dac +} + +// SetPassword sets the "password" field. +func (dac *DavAccountCreate) SetPassword(s string) *DavAccountCreate { + dac.mutation.SetPassword(s) + return dac +} + +// SetOptions sets the "options" field. +func (dac *DavAccountCreate) SetOptions(bs *boolset.BooleanSet) *DavAccountCreate { + dac.mutation.SetOptions(bs) + return dac +} + +// SetProps sets the "props" field. +func (dac *DavAccountCreate) SetProps(tap *types.DavAccountProps) *DavAccountCreate { + dac.mutation.SetProps(tap) + return dac +} + +// SetOwnerID sets the "owner_id" field. +func (dac *DavAccountCreate) SetOwnerID(i int) *DavAccountCreate { + dac.mutation.SetOwnerID(i) + return dac +} + +// SetOwner sets the "owner" edge to the User entity. +func (dac *DavAccountCreate) SetOwner(u *User) *DavAccountCreate { + return dac.SetOwnerID(u.ID) +} + +// Mutation returns the DavAccountMutation object of the builder. +func (dac *DavAccountCreate) Mutation() *DavAccountMutation { + return dac.mutation +} + +// Save creates the DavAccount in the database. +func (dac *DavAccountCreate) Save(ctx context.Context) (*DavAccount, error) { + if err := dac.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, dac.sqlSave, dac.mutation, dac.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (dac *DavAccountCreate) SaveX(ctx context.Context) *DavAccount { + v, err := dac.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (dac *DavAccountCreate) Exec(ctx context.Context) error { + _, err := dac.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dac *DavAccountCreate) ExecX(ctx context.Context) { + if err := dac.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
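DavAccountCreate is a plain fluent builder: set the required fields (the check() method later in this file enforces name, uri, password, options and owner_id), then Save. A minimal sketch; the empty option set and the plain-text password are placeholders for whatever the service layer actually derives, and it assumes the application imports ent/runtime so the created_at/updated_at defaults are registered:

package example // hypothetical caller package

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
)

// createAccount persists a new DAV account for ownerID.
func createAccount(ctx context.Context, client *ent.Client, ownerID int, name, uri, password string) (*ent.DavAccount, error) {
	return client.DavAccount.Create().
		SetName(name).
		SetURI(uri).
		SetPassword(password).
		SetOptions(new(boolset.BooleanSet)). // placeholder: no option bits set
		SetOwnerID(ownerID).
		Save(ctx)
}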
+func (dac *DavAccountCreate) defaults() error { + if _, ok := dac.mutation.CreatedAt(); !ok { + if davaccount.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized davaccount.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := davaccount.DefaultCreatedAt() + dac.mutation.SetCreatedAt(v) + } + if _, ok := dac.mutation.UpdatedAt(); !ok { + if davaccount.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized davaccount.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := davaccount.DefaultUpdatedAt() + dac.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (dac *DavAccountCreate) check() error { + if _, ok := dac.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "DavAccount.created_at"`)} + } + if _, ok := dac.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "DavAccount.updated_at"`)} + } + if _, ok := dac.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "DavAccount.name"`)} + } + if _, ok := dac.mutation.URI(); !ok { + return &ValidationError{Name: "uri", err: errors.New(`ent: missing required field "DavAccount.uri"`)} + } + if _, ok := dac.mutation.Password(); !ok { + return &ValidationError{Name: "password", err: errors.New(`ent: missing required field "DavAccount.password"`)} + } + if _, ok := dac.mutation.Options(); !ok { + return &ValidationError{Name: "options", err: errors.New(`ent: missing required field "DavAccount.options"`)} + } + if _, ok := dac.mutation.OwnerID(); !ok { + return &ValidationError{Name: "owner_id", err: errors.New(`ent: missing required field "DavAccount.owner_id"`)} + } + if _, ok := dac.mutation.OwnerID(); !ok { + return &ValidationError{Name: "owner", err: errors.New(`ent: missing required edge "DavAccount.owner"`)} + } + return nil +} + +func (dac *DavAccountCreate) sqlSave(ctx context.Context) (*DavAccount, error) { + if err := dac.check(); err != nil { + return nil, err + } + _node, _spec := dac.createSpec() + if err := sqlgraph.CreateNode(ctx, dac.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + dac.mutation.id = &_node.ID + dac.mutation.done = true + return _node, nil +} + +func (dac *DavAccountCreate) createSpec() (*DavAccount, *sqlgraph.CreateSpec) { + var ( + _node = &DavAccount{config: dac.config} + _spec = sqlgraph.NewCreateSpec(davaccount.Table, sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt)) + ) + + if id, ok := dac.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = dac.conflict + if value, ok := dac.mutation.CreatedAt(); ok { + _spec.SetField(davaccount.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := dac.mutation.UpdatedAt(); ok { + _spec.SetField(davaccount.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := dac.mutation.DeletedAt(); ok { + _spec.SetField(davaccount.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := dac.mutation.Name(); ok { + _spec.SetField(davaccount.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := dac.mutation.URI(); ok { + _spec.SetField(davaccount.FieldURI, 
field.TypeString, value) + _node.URI = value + } + if value, ok := dac.mutation.Password(); ok { + _spec.SetField(davaccount.FieldPassword, field.TypeString, value) + _node.Password = value + } + if value, ok := dac.mutation.Options(); ok { + _spec.SetField(davaccount.FieldOptions, field.TypeBytes, value) + _node.Options = value + } + if value, ok := dac.mutation.Props(); ok { + _spec.SetField(davaccount.FieldProps, field.TypeJSON, value) + _node.Props = value + } + if nodes := dac.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: davaccount.OwnerTable, + Columns: []string{davaccount.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.OwnerID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.DavAccount.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.DavAccountUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (dac *DavAccountCreate) OnConflict(opts ...sql.ConflictOption) *DavAccountUpsertOne { + dac.conflict = opts + return &DavAccountUpsertOne{ + create: dac, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (dac *DavAccountCreate) OnConflictColumns(columns ...string) *DavAccountUpsertOne { + dac.conflict = append(dac.conflict, sql.ConflictColumns(columns...)) + return &DavAccountUpsertOne{ + create: dac, + } +} + +type ( + // DavAccountUpsertOne is the builder for "upsert"-ing + // one DavAccount node. + DavAccountUpsertOne struct { + create *DavAccountCreate + } + + // DavAccountUpsert is the "OnConflict" setter. + DavAccountUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *DavAccountUpsert) SetUpdatedAt(v time.Time) *DavAccountUpsert { + u.Set(davaccount.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateUpdatedAt() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DavAccountUpsert) SetDeletedAt(v time.Time) *DavAccountUpsert { + u.Set(davaccount.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateDeletedAt() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DavAccountUpsert) ClearDeletedAt() *DavAccountUpsert { + u.SetNull(davaccount.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *DavAccountUpsert) SetName(v string) *DavAccountUpsert { + u.Set(davaccount.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. 
+func (u *DavAccountUpsert) UpdateName() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldName) + return u +} + +// SetURI sets the "uri" field. +func (u *DavAccountUpsert) SetURI(v string) *DavAccountUpsert { + u.Set(davaccount.FieldURI, v) + return u +} + +// UpdateURI sets the "uri" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateURI() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldURI) + return u +} + +// SetPassword sets the "password" field. +func (u *DavAccountUpsert) SetPassword(v string) *DavAccountUpsert { + u.Set(davaccount.FieldPassword, v) + return u +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdatePassword() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldPassword) + return u +} + +// SetOptions sets the "options" field. +func (u *DavAccountUpsert) SetOptions(v *boolset.BooleanSet) *DavAccountUpsert { + u.Set(davaccount.FieldOptions, v) + return u +} + +// UpdateOptions sets the "options" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateOptions() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldOptions) + return u +} + +// SetProps sets the "props" field. +func (u *DavAccountUpsert) SetProps(v *types.DavAccountProps) *DavAccountUpsert { + u.Set(davaccount.FieldProps, v) + return u +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateProps() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldProps) + return u +} + +// ClearProps clears the value of the "props" field. +func (u *DavAccountUpsert) ClearProps() *DavAccountUpsert { + u.SetNull(davaccount.FieldProps) + return u +} + +// SetOwnerID sets the "owner_id" field. +func (u *DavAccountUpsert) SetOwnerID(v int) *DavAccountUpsert { + u.Set(davaccount.FieldOwnerID, v) + return u +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *DavAccountUpsert) UpdateOwnerID() *DavAccountUpsert { + u.SetExcluded(davaccount.FieldOwnerID) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *DavAccountUpsertOne) UpdateNewValues() *DavAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(davaccount.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *DavAccountUpsertOne) Ignore() *DavAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *DavAccountUpsertOne) DoNothing() *DavAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the DavAccountCreate.OnConflict +// documentation for more info. 
+func (u *DavAccountUpsertOne) Update(set func(*DavAccountUpsert)) *DavAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&DavAccountUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *DavAccountUpsertOne) SetUpdatedAt(v time.Time) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateUpdatedAt() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DavAccountUpsertOne) SetDeletedAt(v time.Time) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateDeletedAt() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DavAccountUpsertOne) ClearDeletedAt() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *DavAccountUpsertOne) SetName(v string) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateName() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateName() + }) +} + +// SetURI sets the "uri" field. +func (u *DavAccountUpsertOne) SetURI(v string) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetURI(v) + }) +} + +// UpdateURI sets the "uri" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateURI() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateURI() + }) +} + +// SetPassword sets the "password" field. +func (u *DavAccountUpsertOne) SetPassword(v string) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdatePassword() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdatePassword() + }) +} + +// SetOptions sets the "options" field. +func (u *DavAccountUpsertOne) SetOptions(v *boolset.BooleanSet) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetOptions(v) + }) +} + +// UpdateOptions sets the "options" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateOptions() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateOptions() + }) +} + +// SetProps sets the "props" field. +func (u *DavAccountUpsertOne) SetProps(v *types.DavAccountProps) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetProps(v) + }) +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateProps() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateProps() + }) +} + +// ClearProps clears the value of the "props" field. 
+func (u *DavAccountUpsertOne) ClearProps() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.ClearProps() + }) +} + +// SetOwnerID sets the "owner_id" field. +func (u *DavAccountUpsertOne) SetOwnerID(v int) *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.SetOwnerID(v) + }) +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *DavAccountUpsertOne) UpdateOwnerID() *DavAccountUpsertOne { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateOwnerID() + }) +} + +// Exec executes the query. +func (u *DavAccountUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for DavAccountCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *DavAccountUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *DavAccountUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *DavAccountUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *DavAccountCreate) SetRawID(t int) *DavAccountCreate { + m.mutation.SetRawID(t) + return m +} + +// DavAccountCreateBulk is the builder for creating many DavAccount entities in bulk. +type DavAccountCreateBulk struct { + config + err error + builders []*DavAccountCreate + conflict []sql.ConflictOption +} + +// Save creates the DavAccount entities in the database. +func (dacb *DavAccountCreateBulk) Save(ctx context.Context) ([]*DavAccount, error) { + if dacb.err != nil { + return nil, dacb.err + } + specs := make([]*sqlgraph.CreateSpec, len(dacb.builders)) + nodes := make([]*DavAccount, len(dacb.builders)) + mutators := make([]Mutator, len(dacb.builders)) + for i := range dacb.builders { + func(i int, root context.Context) { + builder := dacb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*DavAccountMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, dacb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = dacb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, dacb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, dacb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. 
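The UpsertOne helpers wrap ON CONFLICT / ON DUPLICATE KEY handling: pick a conflict target, choose a resolution strategy, and optionally override individual columns. A sketch that assumes a unique index over (owner_id, uri); the conflict target is hypothetical and must match a real constraint in the schema, and column-based targets are only honored by SQLite and PostgreSQL:

package example // hypothetical caller package

import (
	"context"

	"entgo.io/ent/dialect/sql"
	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
)

// upsertAccount inserts the account or, on conflict with the assumed unique
// index, refreshes name and password, returning the row ID either way.
func upsertAccount(ctx context.Context, client *ent.Client, ownerID int, name, uri, password string) (int, error) {
	return client.DavAccount.Create().
		SetName(name).
		SetURI(uri).
		SetPassword(password).
		SetOptions(new(boolset.BooleanSet)). // placeholder option set
		SetOwnerID(ownerID).
		OnConflict(
			sql.ConflictColumns("owner_id", "uri"), // hypothetical unique index
			sql.ResolveWithNewValues(),
		).
		Update(func(u *ent.DavAccountUpsert) {
			u.UpdateName()
			u.UpdatePassword()
		}).
		ID(ctx)
}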
+func (dacb *DavAccountCreateBulk) SaveX(ctx context.Context) []*DavAccount { + v, err := dacb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (dacb *DavAccountCreateBulk) Exec(ctx context.Context) error { + _, err := dacb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dacb *DavAccountCreateBulk) ExecX(ctx context.Context) { + if err := dacb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.DavAccount.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.DavAccountUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (dacb *DavAccountCreateBulk) OnConflict(opts ...sql.ConflictOption) *DavAccountUpsertBulk { + dacb.conflict = opts + return &DavAccountUpsertBulk{ + create: dacb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (dacb *DavAccountCreateBulk) OnConflictColumns(columns ...string) *DavAccountUpsertBulk { + dacb.conflict = append(dacb.conflict, sql.ConflictColumns(columns...)) + return &DavAccountUpsertBulk{ + create: dacb, + } +} + +// DavAccountUpsertBulk is the builder for "upsert"-ing +// a bulk of DavAccount nodes. +type DavAccountUpsertBulk struct { + create *DavAccountCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *DavAccountUpsertBulk) UpdateNewValues() *DavAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(davaccount.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.DavAccount.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *DavAccountUpsertBulk) Ignore() *DavAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *DavAccountUpsertBulk) DoNothing() *DavAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the DavAccountCreateBulk.OnConflict +// documentation for more info. +func (u *DavAccountUpsertBulk) Update(set func(*DavAccountUpsert)) *DavAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&DavAccountUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. 
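CreateBulk batches many DavAccountCreate builders into a single INSERT, with the same OnConflict options available on the bulk builder. An illustrative seeding sketch with placeholder field values:

package example // hypothetical caller package

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
)

// seedAccounts builds one create builder per URI and inserts them in one batch.
func seedAccounts(ctx context.Context, client *ent.Client, ownerID int, uris []string) ([]*ent.DavAccount, error) {
	builders := make([]*ent.DavAccountCreate, 0, len(uris))
	for _, u := range uris {
		builders = append(builders, client.DavAccount.Create().
			SetName("seeded").        // placeholder name
			SetURI(u).
			SetPassword("change-me"). // placeholder secret
			SetOptions(new(boolset.BooleanSet)).
			SetOwnerID(ownerID))
	}
	return client.DavAccount.CreateBulk(builders...).Save(ctx)
}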
+func (u *DavAccountUpsertBulk) SetUpdatedAt(v time.Time) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateUpdatedAt() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DavAccountUpsertBulk) SetDeletedAt(v time.Time) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateDeletedAt() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DavAccountUpsertBulk) ClearDeletedAt() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *DavAccountUpsertBulk) SetName(v string) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateName() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateName() + }) +} + +// SetURI sets the "uri" field. +func (u *DavAccountUpsertBulk) SetURI(v string) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetURI(v) + }) +} + +// UpdateURI sets the "uri" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateURI() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateURI() + }) +} + +// SetPassword sets the "password" field. +func (u *DavAccountUpsertBulk) SetPassword(v string) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdatePassword() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdatePassword() + }) +} + +// SetOptions sets the "options" field. +func (u *DavAccountUpsertBulk) SetOptions(v *boolset.BooleanSet) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetOptions(v) + }) +} + +// UpdateOptions sets the "options" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateOptions() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateOptions() + }) +} + +// SetProps sets the "props" field. +func (u *DavAccountUpsertBulk) SetProps(v *types.DavAccountProps) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetProps(v) + }) +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateProps() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateProps() + }) +} + +// ClearProps clears the value of the "props" field. +func (u *DavAccountUpsertBulk) ClearProps() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.ClearProps() + }) +} + +// SetOwnerID sets the "owner_id" field. 
+func (u *DavAccountUpsertBulk) SetOwnerID(v int) *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.SetOwnerID(v) + }) +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *DavAccountUpsertBulk) UpdateOwnerID() *DavAccountUpsertBulk { + return u.Update(func(s *DavAccountUpsert) { + s.UpdateOwnerID() + }) +} + +// Exec executes the query. +func (u *DavAccountUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the DavAccountCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for DavAccountCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *DavAccountUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/davaccount_delete.go b/ent/davaccount_delete.go new file mode 100644 index 00000000..849ed3c0 --- /dev/null +++ b/ent/davaccount_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// DavAccountDelete is the builder for deleting a DavAccount entity. +type DavAccountDelete struct { + config + hooks []Hook + mutation *DavAccountMutation +} + +// Where appends a list predicates to the DavAccountDelete builder. +func (dad *DavAccountDelete) Where(ps ...predicate.DavAccount) *DavAccountDelete { + dad.mutation.Where(ps...) + return dad +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (dad *DavAccountDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, dad.sqlExec, dad.mutation, dad.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (dad *DavAccountDelete) ExecX(ctx context.Context) int { + n, err := dad.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (dad *DavAccountDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(davaccount.Table, sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt)) + if ps := dad.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, dad.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + dad.mutation.done = true + return affected, err +} + +// DavAccountDeleteOne is the builder for deleting a single DavAccount entity. +type DavAccountDeleteOne struct { + dad *DavAccountDelete +} + +// Where appends a list predicates to the DavAccountDelete builder. +func (dado *DavAccountDeleteOne) Where(ps ...predicate.DavAccount) *DavAccountDeleteOne { + dado.dad.mutation.Where(ps...) + return dado +} + +// Exec executes the deletion query. +func (dado *DavAccountDeleteOne) Exec(ctx context.Context) error { + n, err := dado.dad.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{davaccount.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
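The DavAccountDelete builder introduced above follows the usual ent pattern. A short sketch of predicate-based deletion, assuming the generated davaccount predicate helpers (not shown in this hunk) mirror the directlink ones later in the diff:

// purgeSoftDeleted removes all soft-deleted accounts of one owner and
// returns how many rows were affected.
func purgeSoftDeleted(ctx context.Context, client *ent.Client, ownerID int) (int, error) {
	return client.DavAccount.Delete().
		Where(
			davaccount.OwnerID(ownerID),  // assumed generated field predicate
			davaccount.DeletedAtNotNil(), // assumed generated nil-check predicate
		).
		Exec(ctx)
}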
+func (dado *DavAccountDeleteOne) ExecX(ctx context.Context) { + if err := dado.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/davaccount_query.go b/ent/davaccount_query.go new file mode 100644 index 00000000..24de72af --- /dev/null +++ b/ent/davaccount_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// DavAccountQuery is the builder for querying DavAccount entities. +type DavAccountQuery struct { + config + ctx *QueryContext + order []davaccount.OrderOption + inters []Interceptor + predicates []predicate.DavAccount + withOwner *UserQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the DavAccountQuery builder. +func (daq *DavAccountQuery) Where(ps ...predicate.DavAccount) *DavAccountQuery { + daq.predicates = append(daq.predicates, ps...) + return daq +} + +// Limit the number of records to be returned by this query. +func (daq *DavAccountQuery) Limit(limit int) *DavAccountQuery { + daq.ctx.Limit = &limit + return daq +} + +// Offset to start from. +func (daq *DavAccountQuery) Offset(offset int) *DavAccountQuery { + daq.ctx.Offset = &offset + return daq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (daq *DavAccountQuery) Unique(unique bool) *DavAccountQuery { + daq.ctx.Unique = &unique + return daq +} + +// Order specifies how the records should be ordered. +func (daq *DavAccountQuery) Order(o ...davaccount.OrderOption) *DavAccountQuery { + daq.order = append(daq.order, o...) + return daq +} + +// QueryOwner chains the current query on the "owner" edge. +func (daq *DavAccountQuery) QueryOwner() *UserQuery { + query := (&UserClient{config: daq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := daq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := daq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(davaccount.Table, davaccount.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, davaccount.OwnerTable, davaccount.OwnerColumn), + ) + fromU = sqlgraph.SetNeighbors(daq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first DavAccount entity from the query. +// Returns a *NotFoundError when no DavAccount was found. +func (daq *DavAccountQuery) First(ctx context.Context) (*DavAccount, error) { + nodes, err := daq.Limit(1).All(setContextOp(ctx, daq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{davaccount.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (daq *DavAccountQuery) FirstX(ctx context.Context) *DavAccount { + node, err := daq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first DavAccount ID from the query. +// Returns a *NotFoundError when no DavAccount ID was found. 
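The query builder's entry points (Where, Order, First, QueryOwner) compose as in this sketch. The davaccount predicate and order helpers are assumed to exist in the same shape as the directlink ones later in the diff, and sql here is entgo.io/ent/dialect/sql.

// accountOwner walks from a DavAccount to its owner in a single query.
func accountOwner(ctx context.Context, client *ent.Client, name string) (*ent.User, error) {
	return client.DavAccount.Query().
		Where(davaccount.NameEQ(name)). // assumed predicate helper
		QueryOwner().
		First(ctx)
}

// newestAccount returns the most recently created account of one owner.
func newestAccount(ctx context.Context, client *ent.Client, ownerID int) (*ent.DavAccount, error) {
	return client.DavAccount.Query().
		Where(davaccount.OwnerID(ownerID)).             // assumed predicate helper
		Order(davaccount.ByCreatedAt(sql.OrderDesc())). // assumed order helper
		First(ctx)
}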
+func (daq *DavAccountQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = daq.Limit(1).IDs(setContextOp(ctx, daq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{davaccount.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (daq *DavAccountQuery) FirstIDX(ctx context.Context) int { + id, err := daq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single DavAccount entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one DavAccount entity is found. +// Returns a *NotFoundError when no DavAccount entities are found. +func (daq *DavAccountQuery) Only(ctx context.Context) (*DavAccount, error) { + nodes, err := daq.Limit(2).All(setContextOp(ctx, daq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{davaccount.Label} + default: + return nil, &NotSingularError{davaccount.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (daq *DavAccountQuery) OnlyX(ctx context.Context) *DavAccount { + node, err := daq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only DavAccount ID in the query. +// Returns a *NotSingularError when more than one DavAccount ID is found. +// Returns a *NotFoundError when no entities are found. +func (daq *DavAccountQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = daq.Limit(2).IDs(setContextOp(ctx, daq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{davaccount.Label} + default: + err = &NotSingularError{davaccount.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (daq *DavAccountQuery) OnlyIDX(ctx context.Context) int { + id, err := daq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of DavAccounts. +func (daq *DavAccountQuery) All(ctx context.Context) ([]*DavAccount, error) { + ctx = setContextOp(ctx, daq.ctx, "All") + if err := daq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*DavAccount, *DavAccountQuery]() + return withInterceptors[[]*DavAccount](ctx, daq, qr, daq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (daq *DavAccountQuery) AllX(ctx context.Context) []*DavAccount { + nodes, err := daq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of DavAccount IDs. +func (daq *DavAccountQuery) IDs(ctx context.Context) (ids []int, err error) { + if daq.ctx.Unique == nil && daq.path != nil { + daq.Unique(true) + } + ctx = setContextOp(ctx, daq.ctx, "IDs") + if err = daq.Select(davaccount.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (daq *DavAccountQuery) IDsX(ctx context.Context) []int { + ids, err := daq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. 
+func (daq *DavAccountQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, daq.ctx, "Count") + if err := daq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, daq, querierCount[*DavAccountQuery](), daq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (daq *DavAccountQuery) CountX(ctx context.Context) int { + count, err := daq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (daq *DavAccountQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, daq.ctx, "Exist") + switch _, err := daq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (daq *DavAccountQuery) ExistX(ctx context.Context) bool { + exist, err := daq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the DavAccountQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (daq *DavAccountQuery) Clone() *DavAccountQuery { + if daq == nil { + return nil + } + return &DavAccountQuery{ + config: daq.config, + ctx: daq.ctx.Clone(), + order: append([]davaccount.OrderOption{}, daq.order...), + inters: append([]Interceptor{}, daq.inters...), + predicates: append([]predicate.DavAccount{}, daq.predicates...), + withOwner: daq.withOwner.Clone(), + // clone intermediate query. + sql: daq.sql.Clone(), + path: daq.path, + } +} + +// WithOwner tells the query-builder to eager-load the nodes that are connected to +// the "owner" edge. The optional arguments are used to configure the query builder of the edge. +func (daq *DavAccountQuery) WithOwner(opts ...func(*UserQuery)) *DavAccountQuery { + query := (&UserClient{config: daq.config}).Query() + for _, opt := range opts { + opt(query) + } + daq.withOwner = query + return daq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.DavAccount.Query(). +// GroupBy(davaccount.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (daq *DavAccountQuery) GroupBy(field string, fields ...string) *DavAccountGroupBy { + daq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DavAccountGroupBy{build: daq} + grbuild.flds = &daq.ctx.Fields + grbuild.label = davaccount.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.DavAccount.Query(). +// Select(davaccount.FieldCreatedAt). +// Scan(ctx, &v) +func (daq *DavAccountQuery) Select(fields ...string) *DavAccountSelect { + daq.ctx.Fields = append(daq.ctx.Fields, fields...) + sbuild := &DavAccountSelect{DavAccountQuery: daq} + sbuild.label = davaccount.Label + sbuild.flds, sbuild.scan = &daq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DavAccountSelect configured with the given aggregations. 
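WithOwner, GroupBy and Select shown above combine as in the following sketch. The json tags mirror the column names, as in the generated GroupBy example, and ent.Count is the standard aggregation helper; everything else reuses the assumptions of the earlier sketches.

// accountsPerOwner eager-loads each account's owner and also counts accounts per owner.
func accountsPerOwner(ctx context.Context, client *ent.Client) error {
	accounts, err := client.DavAccount.Query().
		WithOwner(). // one extra query; owners land in Edges.Owner
		All(ctx)
	if err != nil {
		return err
	}
	for _, a := range accounts {
		_ = a.Edges.Owner // populated by the eager-loading above
	}

	var rows []struct {
		OwnerID int `json:"owner_id"`
		Count   int `json:"count"`
	}
	return client.DavAccount.Query().
		GroupBy(davaccount.FieldOwnerID).
		Aggregate(ent.Count()).
		Scan(ctx, &rows)
}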
+func (daq *DavAccountQuery) Aggregate(fns ...AggregateFunc) *DavAccountSelect { + return daq.Select().Aggregate(fns...) +} + +func (daq *DavAccountQuery) prepareQuery(ctx context.Context) error { + for _, inter := range daq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, daq); err != nil { + return err + } + } + } + for _, f := range daq.ctx.Fields { + if !davaccount.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if daq.path != nil { + prev, err := daq.path(ctx) + if err != nil { + return err + } + daq.sql = prev + } + return nil +} + +func (daq *DavAccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*DavAccount, error) { + var ( + nodes = []*DavAccount{} + _spec = daq.querySpec() + loadedTypes = [1]bool{ + daq.withOwner != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*DavAccount).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &DavAccount{config: daq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, daq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := daq.withOwner; query != nil { + if err := daq.loadOwner(ctx, query, nodes, nil, + func(n *DavAccount, e *User) { n.Edges.Owner = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (daq *DavAccountQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*DavAccount, init func(*DavAccount), assign func(*DavAccount, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*DavAccount) + for i := range nodes { + fk := nodes[i].OwnerID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (daq *DavAccountQuery) sqlCount(ctx context.Context) (int, error) { + _spec := daq.querySpec() + _spec.Node.Columns = daq.ctx.Fields + if len(daq.ctx.Fields) > 0 { + _spec.Unique = daq.ctx.Unique != nil && *daq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, daq.driver, _spec) +} + +func (daq *DavAccountQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(davaccount.Table, davaccount.Columns, sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt)) + _spec.From = daq.sql + if unique := daq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if daq.path != nil { + _spec.Unique = true + } + if fields := daq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, davaccount.FieldID) + for i := range fields { + if fields[i] != davaccount.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if daq.withOwner != nil { + _spec.Node.AddColumnOnce(davaccount.FieldOwnerID) + } + } + if ps := daq.predicates; len(ps) > 0 { + 
_spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := daq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := daq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := daq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (daq *DavAccountQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(daq.driver.Dialect()) + t1 := builder.Table(davaccount.Table) + columns := daq.ctx.Fields + if len(columns) == 0 { + columns = davaccount.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if daq.sql != nil { + selector = daq.sql + selector.Select(selector.Columns(columns...)...) + } + if daq.ctx.Unique != nil && *daq.ctx.Unique { + selector.Distinct() + } + for _, p := range daq.predicates { + p(selector) + } + for _, p := range daq.order { + p(selector) + } + if offset := daq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := daq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// DavAccountGroupBy is the group-by builder for DavAccount entities. +type DavAccountGroupBy struct { + selector + build *DavAccountQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (dagb *DavAccountGroupBy) Aggregate(fns ...AggregateFunc) *DavAccountGroupBy { + dagb.fns = append(dagb.fns, fns...) + return dagb +} + +// Scan applies the selector query and scans the result into the given value. +func (dagb *DavAccountGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dagb.build.ctx, "GroupBy") + if err := dagb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*DavAccountQuery, *DavAccountGroupBy](ctx, dagb.build, dagb, dagb.build.inters, v) +} + +func (dagb *DavAccountGroupBy) sqlScan(ctx context.Context, root *DavAccountQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dagb.fns)) + for _, fn := range dagb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dagb.flds)+len(dagb.fns)) + for _, f := range *dagb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*dagb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := dagb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// DavAccountSelect is the builder for selecting fields of DavAccount entities. +type DavAccountSelect struct { + *DavAccountQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (das *DavAccountSelect) Aggregate(fns ...AggregateFunc) *DavAccountSelect { + das.fns = append(das.fns, fns...) + return das +} + +// Scan applies the selector query and scans the result into the given value. 
+func (das *DavAccountSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, das.ctx, "Select") + if err := das.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*DavAccountQuery, *DavAccountSelect](ctx, das.DavAccountQuery, das, das.inters, v) +} + +func (das *DavAccountSelect) sqlScan(ctx context.Context, root *DavAccountQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(das.fns)) + for _, fn := range das.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*das.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := das.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/davaccount_update.go b/ent/davaccount_update.go new file mode 100644 index 00000000..c98deb62 --- /dev/null +++ b/ent/davaccount_update.go @@ -0,0 +1,565 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// DavAccountUpdate is the builder for updating DavAccount entities. +type DavAccountUpdate struct { + config + hooks []Hook + mutation *DavAccountMutation +} + +// Where appends a list predicates to the DavAccountUpdate builder. +func (dau *DavAccountUpdate) Where(ps ...predicate.DavAccount) *DavAccountUpdate { + dau.mutation.Where(ps...) + return dau +} + +// SetUpdatedAt sets the "updated_at" field. +func (dau *DavAccountUpdate) SetUpdatedAt(t time.Time) *DavAccountUpdate { + dau.mutation.SetUpdatedAt(t) + return dau +} + +// SetDeletedAt sets the "deleted_at" field. +func (dau *DavAccountUpdate) SetDeletedAt(t time.Time) *DavAccountUpdate { + dau.mutation.SetDeletedAt(t) + return dau +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (dau *DavAccountUpdate) SetNillableDeletedAt(t *time.Time) *DavAccountUpdate { + if t != nil { + dau.SetDeletedAt(*t) + } + return dau +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (dau *DavAccountUpdate) ClearDeletedAt() *DavAccountUpdate { + dau.mutation.ClearDeletedAt() + return dau +} + +// SetName sets the "name" field. +func (dau *DavAccountUpdate) SetName(s string) *DavAccountUpdate { + dau.mutation.SetName(s) + return dau +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (dau *DavAccountUpdate) SetNillableName(s *string) *DavAccountUpdate { + if s != nil { + dau.SetName(*s) + } + return dau +} + +// SetURI sets the "uri" field. +func (dau *DavAccountUpdate) SetURI(s string) *DavAccountUpdate { + dau.mutation.SetURI(s) + return dau +} + +// SetNillableURI sets the "uri" field if the given value is not nil. +func (dau *DavAccountUpdate) SetNillableURI(s *string) *DavAccountUpdate { + if s != nil { + dau.SetURI(*s) + } + return dau +} + +// SetPassword sets the "password" field. 
+func (dau *DavAccountUpdate) SetPassword(s string) *DavAccountUpdate { + dau.mutation.SetPassword(s) + return dau +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (dau *DavAccountUpdate) SetNillablePassword(s *string) *DavAccountUpdate { + if s != nil { + dau.SetPassword(*s) + } + return dau +} + +// SetOptions sets the "options" field. +func (dau *DavAccountUpdate) SetOptions(bs *boolset.BooleanSet) *DavAccountUpdate { + dau.mutation.SetOptions(bs) + return dau +} + +// SetProps sets the "props" field. +func (dau *DavAccountUpdate) SetProps(tap *types.DavAccountProps) *DavAccountUpdate { + dau.mutation.SetProps(tap) + return dau +} + +// ClearProps clears the value of the "props" field. +func (dau *DavAccountUpdate) ClearProps() *DavAccountUpdate { + dau.mutation.ClearProps() + return dau +} + +// SetOwnerID sets the "owner_id" field. +func (dau *DavAccountUpdate) SetOwnerID(i int) *DavAccountUpdate { + dau.mutation.SetOwnerID(i) + return dau +} + +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (dau *DavAccountUpdate) SetNillableOwnerID(i *int) *DavAccountUpdate { + if i != nil { + dau.SetOwnerID(*i) + } + return dau +} + +// SetOwner sets the "owner" edge to the User entity. +func (dau *DavAccountUpdate) SetOwner(u *User) *DavAccountUpdate { + return dau.SetOwnerID(u.ID) +} + +// Mutation returns the DavAccountMutation object of the builder. +func (dau *DavAccountUpdate) Mutation() *DavAccountMutation { + return dau.mutation +} + +// ClearOwner clears the "owner" edge to the User entity. +func (dau *DavAccountUpdate) ClearOwner() *DavAccountUpdate { + dau.mutation.ClearOwner() + return dau +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (dau *DavAccountUpdate) Save(ctx context.Context) (int, error) { + if err := dau.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, dau.sqlSave, dau.mutation, dau.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (dau *DavAccountUpdate) SaveX(ctx context.Context) int { + affected, err := dau.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (dau *DavAccountUpdate) Exec(ctx context.Context) error { + _, err := dau.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dau *DavAccountUpdate) ExecX(ctx context.Context) { + if err := dau.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (dau *DavAccountUpdate) defaults() error { + if _, ok := dau.mutation.UpdatedAt(); !ok { + if davaccount.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized davaccount.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := davaccount.UpdateDefaultUpdatedAt() + dau.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
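A sketch of the update builder in use: updated_at is filled in by defaults() before the statement runs, and davaccount.OwnerID is an assumed generated predicate, as in the earlier sketches.

// rotatePasswords sets a new password on every account of one owner and
// returns the number of rows touched.
func rotatePasswords(ctx context.Context, client *ent.Client, ownerID int, newPassword string) (int, error) {
	return client.DavAccount.Update().
		Where(davaccount.OwnerID(ownerID)). // assumed predicate helper
		SetPassword(newPassword).
		Save(ctx)
}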
+func (dau *DavAccountUpdate) check() error { + if _, ok := dau.mutation.OwnerID(); dau.mutation.OwnerCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "DavAccount.owner"`) + } + return nil +} + +func (dau *DavAccountUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := dau.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(davaccount.Table, davaccount.Columns, sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt)) + if ps := dau.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := dau.mutation.UpdatedAt(); ok { + _spec.SetField(davaccount.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := dau.mutation.DeletedAt(); ok { + _spec.SetField(davaccount.FieldDeletedAt, field.TypeTime, value) + } + if dau.mutation.DeletedAtCleared() { + _spec.ClearField(davaccount.FieldDeletedAt, field.TypeTime) + } + if value, ok := dau.mutation.Name(); ok { + _spec.SetField(davaccount.FieldName, field.TypeString, value) + } + if value, ok := dau.mutation.URI(); ok { + _spec.SetField(davaccount.FieldURI, field.TypeString, value) + } + if value, ok := dau.mutation.Password(); ok { + _spec.SetField(davaccount.FieldPassword, field.TypeString, value) + } + if value, ok := dau.mutation.Options(); ok { + _spec.SetField(davaccount.FieldOptions, field.TypeBytes, value) + } + if value, ok := dau.mutation.Props(); ok { + _spec.SetField(davaccount.FieldProps, field.TypeJSON, value) + } + if dau.mutation.PropsCleared() { + _spec.ClearField(davaccount.FieldProps, field.TypeJSON) + } + if dau.mutation.OwnerCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: davaccount.OwnerTable, + Columns: []string{davaccount.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := dau.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: davaccount.OwnerTable, + Columns: []string{davaccount.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, dau.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{davaccount.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + dau.mutation.done = true + return n, nil +} + +// DavAccountUpdateOne is the builder for updating a single DavAccount entity. +type DavAccountUpdateOne struct { + config + fields []string + hooks []Hook + mutation *DavAccountMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (dauo *DavAccountUpdateOne) SetUpdatedAt(t time.Time) *DavAccountUpdateOne { + dauo.mutation.SetUpdatedAt(t) + return dauo +} + +// SetDeletedAt sets the "deleted_at" field. +func (dauo *DavAccountUpdateOne) SetDeletedAt(t time.Time) *DavAccountUpdateOne { + dauo.mutation.SetDeletedAt(t) + return dauo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. 
+func (dauo *DavAccountUpdateOne) SetNillableDeletedAt(t *time.Time) *DavAccountUpdateOne { + if t != nil { + dauo.SetDeletedAt(*t) + } + return dauo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (dauo *DavAccountUpdateOne) ClearDeletedAt() *DavAccountUpdateOne { + dauo.mutation.ClearDeletedAt() + return dauo +} + +// SetName sets the "name" field. +func (dauo *DavAccountUpdateOne) SetName(s string) *DavAccountUpdateOne { + dauo.mutation.SetName(s) + return dauo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (dauo *DavAccountUpdateOne) SetNillableName(s *string) *DavAccountUpdateOne { + if s != nil { + dauo.SetName(*s) + } + return dauo +} + +// SetURI sets the "uri" field. +func (dauo *DavAccountUpdateOne) SetURI(s string) *DavAccountUpdateOne { + dauo.mutation.SetURI(s) + return dauo +} + +// SetNillableURI sets the "uri" field if the given value is not nil. +func (dauo *DavAccountUpdateOne) SetNillableURI(s *string) *DavAccountUpdateOne { + if s != nil { + dauo.SetURI(*s) + } + return dauo +} + +// SetPassword sets the "password" field. +func (dauo *DavAccountUpdateOne) SetPassword(s string) *DavAccountUpdateOne { + dauo.mutation.SetPassword(s) + return dauo +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (dauo *DavAccountUpdateOne) SetNillablePassword(s *string) *DavAccountUpdateOne { + if s != nil { + dauo.SetPassword(*s) + } + return dauo +} + +// SetOptions sets the "options" field. +func (dauo *DavAccountUpdateOne) SetOptions(bs *boolset.BooleanSet) *DavAccountUpdateOne { + dauo.mutation.SetOptions(bs) + return dauo +} + +// SetProps sets the "props" field. +func (dauo *DavAccountUpdateOne) SetProps(tap *types.DavAccountProps) *DavAccountUpdateOne { + dauo.mutation.SetProps(tap) + return dauo +} + +// ClearProps clears the value of the "props" field. +func (dauo *DavAccountUpdateOne) ClearProps() *DavAccountUpdateOne { + dauo.mutation.ClearProps() + return dauo +} + +// SetOwnerID sets the "owner_id" field. +func (dauo *DavAccountUpdateOne) SetOwnerID(i int) *DavAccountUpdateOne { + dauo.mutation.SetOwnerID(i) + return dauo +} + +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (dauo *DavAccountUpdateOne) SetNillableOwnerID(i *int) *DavAccountUpdateOne { + if i != nil { + dauo.SetOwnerID(*i) + } + return dauo +} + +// SetOwner sets the "owner" edge to the User entity. +func (dauo *DavAccountUpdateOne) SetOwner(u *User) *DavAccountUpdateOne { + return dauo.SetOwnerID(u.ID) +} + +// Mutation returns the DavAccountMutation object of the builder. +func (dauo *DavAccountUpdateOne) Mutation() *DavAccountMutation { + return dauo.mutation +} + +// ClearOwner clears the "owner" edge to the User entity. +func (dauo *DavAccountUpdateOne) ClearOwner() *DavAccountUpdateOne { + dauo.mutation.ClearOwner() + return dauo +} + +// Where appends a list predicates to the DavAccountUpdate builder. +func (dauo *DavAccountUpdateOne) Where(ps ...predicate.DavAccount) *DavAccountUpdateOne { + dauo.mutation.Where(ps...) + return dauo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (dauo *DavAccountUpdateOne) Select(field string, fields ...string) *DavAccountUpdateOne { + dauo.fields = append([]string{field}, fields...) + return dauo +} + +// Save executes the query and returns the updated DavAccount entity. 
+func (dauo *DavAccountUpdateOne) Save(ctx context.Context) (*DavAccount, error) { + if err := dauo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, dauo.sqlSave, dauo.mutation, dauo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (dauo *DavAccountUpdateOne) SaveX(ctx context.Context) *DavAccount { + node, err := dauo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (dauo *DavAccountUpdateOne) Exec(ctx context.Context) error { + _, err := dauo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dauo *DavAccountUpdateOne) ExecX(ctx context.Context) { + if err := dauo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (dauo *DavAccountUpdateOne) defaults() error { + if _, ok := dauo.mutation.UpdatedAt(); !ok { + if davaccount.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized davaccount.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := davaccount.UpdateDefaultUpdatedAt() + dauo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (dauo *DavAccountUpdateOne) check() error { + if _, ok := dauo.mutation.OwnerID(); dauo.mutation.OwnerCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "DavAccount.owner"`) + } + return nil +} + +func (dauo *DavAccountUpdateOne) sqlSave(ctx context.Context) (_node *DavAccount, err error) { + if err := dauo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(davaccount.Table, davaccount.Columns, sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt)) + id, ok := dauo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "DavAccount.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := dauo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, davaccount.FieldID) + for _, f := range fields { + if !davaccount.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != davaccount.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := dauo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := dauo.mutation.UpdatedAt(); ok { + _spec.SetField(davaccount.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := dauo.mutation.DeletedAt(); ok { + _spec.SetField(davaccount.FieldDeletedAt, field.TypeTime, value) + } + if dauo.mutation.DeletedAtCleared() { + _spec.ClearField(davaccount.FieldDeletedAt, field.TypeTime) + } + if value, ok := dauo.mutation.Name(); ok { + _spec.SetField(davaccount.FieldName, field.TypeString, value) + } + if value, ok := dauo.mutation.URI(); ok { + _spec.SetField(davaccount.FieldURI, field.TypeString, value) + } + if value, ok := dauo.mutation.Password(); ok { + _spec.SetField(davaccount.FieldPassword, field.TypeString, value) + } + if value, ok := dauo.mutation.Options(); ok { + _spec.SetField(davaccount.FieldOptions, field.TypeBytes, value) + } + if value, ok := dauo.mutation.Props(); ok { + _spec.SetField(davaccount.FieldProps, field.TypeJSON, value) + } + if dauo.mutation.PropsCleared() { + _spec.ClearField(davaccount.FieldProps, field.TypeJSON) + 
} + if dauo.mutation.OwnerCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: davaccount.OwnerTable, + Columns: []string{davaccount.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := dauo.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: davaccount.OwnerTable, + Columns: []string{davaccount.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &DavAccount{config: dauo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, dauo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{davaccount.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + dauo.mutation.done = true + return _node, nil +} diff --git a/ent/directlink.go b/ent/directlink.go new file mode 100644 index 00000000..dbca0c14 --- /dev/null +++ b/ent/directlink.go @@ -0,0 +1,212 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/file" +) + +// DirectLink is the model entity for the DirectLink schema. +type DirectLink struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Downloads holds the value of the "downloads" field. + Downloads int `json:"downloads,omitempty"` + // FileID holds the value of the "file_id" field. + FileID int `json:"file_id,omitempty"` + // Speed holds the value of the "speed" field. + Speed int `json:"speed,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the DirectLinkQuery when eager-loading is set. + Edges DirectLinkEdges `json:"edges"` + selectValues sql.SelectValues +} + +// DirectLinkEdges holds the relations/edges for other nodes in the graph. +type DirectLinkEdges struct { + // File holds the value of the file edge. + File *File `json:"file,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// FileOrErr returns the File value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e DirectLinkEdges) FileOrErr() (*File, error) { + if e.loadedTypes[0] { + if e.File == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: file.Label} + } + return e.File, nil + } + return nil, &NotLoadedError{edge: "file"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*DirectLink) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case directlink.FieldID, directlink.FieldDownloads, directlink.FieldFileID, directlink.FieldSpeed: + values[i] = new(sql.NullInt64) + case directlink.FieldName: + values[i] = new(sql.NullString) + case directlink.FieldCreatedAt, directlink.FieldUpdatedAt, directlink.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the DirectLink fields. +func (dl *DirectLink) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case directlink.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + dl.ID = int(value.Int64) + case directlink.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + dl.CreatedAt = value.Time + } + case directlink.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + dl.UpdatedAt = value.Time + } + case directlink.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + dl.DeletedAt = new(time.Time) + *dl.DeletedAt = value.Time + } + case directlink.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + dl.Name = value.String + } + case directlink.FieldDownloads: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field downloads", values[i]) + } else if value.Valid { + dl.Downloads = int(value.Int64) + } + case directlink.FieldFileID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field file_id", values[i]) + } else if value.Valid { + dl.FileID = int(value.Int64) + } + case directlink.FieldSpeed: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field speed", values[i]) + } else if value.Valid { + dl.Speed = int(value.Int64) + } + default: + dl.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the DirectLink. +// This includes values selected through modifiers, order, etc. +func (dl *DirectLink) Value(name string) (ent.Value, error) { + return dl.selectValues.Get(name) +} + +// QueryFile queries the "file" edge of the DirectLink entity. +func (dl *DirectLink) QueryFile() *FileQuery { + return NewDirectLinkClient(dl.config).QueryFile(dl) +} + +// Update returns a builder for updating this DirectLink. +// Note that you need to call DirectLink.Unwrap() before calling this method if this DirectLink +// was returned from a transaction, and the transaction was committed or rolled back. 
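The DirectLink entity above exposes its file through Edges.FileOrErr and QueryFile. A sketch of reading it either way, assuming the standard generated DirectLink client, a WithFile eager-loading helper on its query builder, and the ent/directlink import:

// linkedFile resolves the file behind a direct link, falling back to an
// explicit query when the edge was not eager-loaded.
func linkedFile(ctx context.Context, client *ent.Client, name string) (*ent.File, error) {
	link, err := client.DirectLink.Query().
		Where(directlink.NameEQ(name)). // predicate from ent/directlink/where.go below
		WithFile().                     // assumed eager-loading helper
		Only(ctx)
	if err != nil {
		return nil, err
	}
	if f, err := link.Edges.FileOrErr(); err == nil {
		return f, nil
	}
	return link.QueryFile().Only(ctx)
}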
+func (dl *DirectLink) Update() *DirectLinkUpdateOne { + return NewDirectLinkClient(dl.config).UpdateOne(dl) +} + +// Unwrap unwraps the DirectLink entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (dl *DirectLink) Unwrap() *DirectLink { + _tx, ok := dl.config.driver.(*txDriver) + if !ok { + panic("ent: DirectLink is not a transactional entity") + } + dl.config.driver = _tx.drv + return dl +} + +// String implements the fmt.Stringer. +func (dl *DirectLink) String() string { + var builder strings.Builder + builder.WriteString("DirectLink(") + builder.WriteString(fmt.Sprintf("id=%v, ", dl.ID)) + builder.WriteString("created_at=") + builder.WriteString(dl.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(dl.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := dl.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(dl.Name) + builder.WriteString(", ") + builder.WriteString("downloads=") + builder.WriteString(fmt.Sprintf("%v", dl.Downloads)) + builder.WriteString(", ") + builder.WriteString("file_id=") + builder.WriteString(fmt.Sprintf("%v", dl.FileID)) + builder.WriteString(", ") + builder.WriteString("speed=") + builder.WriteString(fmt.Sprintf("%v", dl.Speed)) + builder.WriteByte(')') + return builder.String() +} + +// SetFile manually set the edge as loaded state. +func (e *DirectLink) SetFile(v *File) { + e.Edges.File = v + e.Edges.loadedTypes[0] = true +} + +// DirectLinks is a parsable slice of DirectLink. +type DirectLinks []*DirectLink diff --git a/ent/directlink/directlink.go b/ent/directlink/directlink.go new file mode 100644 index 00000000..58672aaa --- /dev/null +++ b/ent/directlink/directlink.go @@ -0,0 +1,138 @@ +// Code generated by ent, DO NOT EDIT. + +package directlink + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the directlink type in the database. + Label = "direct_link" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldDownloads holds the string denoting the downloads field in the database. + FieldDownloads = "downloads" + // FieldFileID holds the string denoting the file_id field in the database. + FieldFileID = "file_id" + // FieldSpeed holds the string denoting the speed field in the database. + FieldSpeed = "speed" + // EdgeFile holds the string denoting the file edge name in mutations. + EdgeFile = "file" + // Table holds the table name of the directlink in the database. + Table = "direct_links" + // FileTable is the table that holds the file relation/edge. + FileTable = "direct_links" + // FileInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. 
+ FileInverseTable = "files" + // FileColumn is the table column denoting the file relation/edge. + FileColumn = "file_id" +) + +// Columns holds all SQL columns for directlink fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldDownloads, + FieldFileID, + FieldSpeed, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the DirectLink queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByDownloads orders the results by the downloads field. +func ByDownloads(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDownloads, opts...).ToFunc() +} + +// ByFileID orders the results by the file_id field. +func ByFileID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFileID, opts...).ToFunc() +} + +// BySpeed orders the results by the speed field. +func BySpeed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSpeed, opts...).ToFunc() +} + +// ByFileField orders the results by file field. +func ByFileField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFileStep(), sql.OrderByField(field, opts...)) + } +} +func newFileStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FileInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) +} diff --git a/ent/directlink/where.go b/ent/directlink/where.go new file mode 100644 index 00000000..0ac9d216 --- /dev/null +++ b/ent/directlink/where.go @@ -0,0 +1,424 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package directlink + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldName, v)) +} + +// Downloads applies equality check predicate on the "downloads" field. It's identical to DownloadsEQ. +func Downloads(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldDownloads, v)) +} + +// FileID applies equality check predicate on the "file_id" field. It's identical to FileIDEQ. +func FileID(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldFileID, v)) +} + +// Speed applies equality check predicate on the "speed" field. It's identical to SpeedEQ. +func Speed(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldSpeed, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. 
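Predicates from this file compose by plain AND when passed together to Where. A sketch combining the time and nil-check predicates defined in this file with an order helper from directlink.go above; imports are as in the earlier sketches, plus time and entgo.io/ent/dialect/sql.

// recentLinks lists links created in the last day that are not soft-deleted,
// most-downloaded first.
func recentLinks(ctx context.Context, client *ent.Client) ([]*ent.DirectLink, error) {
	cutoff := time.Now().Add(-24 * time.Hour)
	return client.DirectLink.Query().
		Where(
			directlink.CreatedAtGTE(cutoff),
			directlink.DeletedAtIsNil(),
		).
		Order(directlink.ByDownloads(sql.OrderDesc())).
		All(ctx)
}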
+func CreatedAtIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. 
+func DeletedAtGT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.DirectLink { + return predicate.DirectLink(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.DirectLink { + return predicate.DirectLink(sql.FieldContainsFold(FieldName, v)) +} + +// DownloadsEQ applies the EQ predicate on the "downloads" field. 
+func DownloadsEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldDownloads, v)) +} + +// DownloadsNEQ applies the NEQ predicate on the "downloads" field. +func DownloadsNEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldDownloads, v)) +} + +// DownloadsIn applies the In predicate on the "downloads" field. +func DownloadsIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldDownloads, vs...)) +} + +// DownloadsNotIn applies the NotIn predicate on the "downloads" field. +func DownloadsNotIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldDownloads, vs...)) +} + +// DownloadsGT applies the GT predicate on the "downloads" field. +func DownloadsGT(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldDownloads, v)) +} + +// DownloadsGTE applies the GTE predicate on the "downloads" field. +func DownloadsGTE(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldDownloads, v)) +} + +// DownloadsLT applies the LT predicate on the "downloads" field. +func DownloadsLT(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldDownloads, v)) +} + +// DownloadsLTE applies the LTE predicate on the "downloads" field. +func DownloadsLTE(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldDownloads, v)) +} + +// FileIDEQ applies the EQ predicate on the "file_id" field. +func FileIDEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldFileID, v)) +} + +// FileIDNEQ applies the NEQ predicate on the "file_id" field. +func FileIDNEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldFileID, v)) +} + +// FileIDIn applies the In predicate on the "file_id" field. +func FileIDIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldFileID, vs...)) +} + +// FileIDNotIn applies the NotIn predicate on the "file_id" field. +func FileIDNotIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldFileID, vs...)) +} + +// SpeedEQ applies the EQ predicate on the "speed" field. +func SpeedEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldEQ(FieldSpeed, v)) +} + +// SpeedNEQ applies the NEQ predicate on the "speed" field. +func SpeedNEQ(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNEQ(FieldSpeed, v)) +} + +// SpeedIn applies the In predicate on the "speed" field. +func SpeedIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldIn(FieldSpeed, vs...)) +} + +// SpeedNotIn applies the NotIn predicate on the "speed" field. +func SpeedNotIn(vs ...int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldNotIn(FieldSpeed, vs...)) +} + +// SpeedGT applies the GT predicate on the "speed" field. +func SpeedGT(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGT(FieldSpeed, v)) +} + +// SpeedGTE applies the GTE predicate on the "speed" field. +func SpeedGTE(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldGTE(FieldSpeed, v)) +} + +// SpeedLT applies the LT predicate on the "speed" field. +func SpeedLT(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLT(FieldSpeed, v)) +} + +// SpeedLTE applies the LTE predicate on the "speed" field. 
+func SpeedLTE(v int) predicate.DirectLink { + return predicate.DirectLink(sql.FieldLTE(FieldSpeed, v)) +} + +// HasFile applies the HasEdge predicate on the "file" edge. +func HasFile() predicate.DirectLink { + return predicate.DirectLink(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFileWith applies the HasEdge predicate on the "file" edge with a given conditions (other predicates). +func HasFileWith(preds ...predicate.File) predicate.DirectLink { + return predicate.DirectLink(func(s *sql.Selector) { + step := newFileStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.DirectLink) predicate.DirectLink { + return predicate.DirectLink(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.DirectLink) predicate.DirectLink { + return predicate.DirectLink(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.DirectLink) predicate.DirectLink { + return predicate.DirectLink(sql.NotPredicates(p)) +} diff --git a/ent/directlink_create.go b/ent/directlink_create.go new file mode 100644 index 00000000..27ad6ff4 --- /dev/null +++ b/ent/directlink_create.go @@ -0,0 +1,883 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/file" +) + +// DirectLinkCreate is the builder for creating a DirectLink entity. +type DirectLinkCreate struct { + config + mutation *DirectLinkMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (dlc *DirectLinkCreate) SetCreatedAt(t time.Time) *DirectLinkCreate { + dlc.mutation.SetCreatedAt(t) + return dlc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (dlc *DirectLinkCreate) SetNillableCreatedAt(t *time.Time) *DirectLinkCreate { + if t != nil { + dlc.SetCreatedAt(*t) + } + return dlc +} + +// SetUpdatedAt sets the "updated_at" field. +func (dlc *DirectLinkCreate) SetUpdatedAt(t time.Time) *DirectLinkCreate { + dlc.mutation.SetUpdatedAt(t) + return dlc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (dlc *DirectLinkCreate) SetNillableUpdatedAt(t *time.Time) *DirectLinkCreate { + if t != nil { + dlc.SetUpdatedAt(*t) + } + return dlc +} + +// SetDeletedAt sets the "deleted_at" field. +func (dlc *DirectLinkCreate) SetDeletedAt(t time.Time) *DirectLinkCreate { + dlc.mutation.SetDeletedAt(t) + return dlc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (dlc *DirectLinkCreate) SetNillableDeletedAt(t *time.Time) *DirectLinkCreate { + if t != nil { + dlc.SetDeletedAt(*t) + } + return dlc +} + +// SetName sets the "name" field. +func (dlc *DirectLinkCreate) SetName(s string) *DirectLinkCreate { + dlc.mutation.SetName(s) + return dlc +} + +// SetDownloads sets the "downloads" field. 
+func (dlc *DirectLinkCreate) SetDownloads(i int) *DirectLinkCreate { + dlc.mutation.SetDownloads(i) + return dlc +} + +// SetFileID sets the "file_id" field. +func (dlc *DirectLinkCreate) SetFileID(i int) *DirectLinkCreate { + dlc.mutation.SetFileID(i) + return dlc +} + +// SetSpeed sets the "speed" field. +func (dlc *DirectLinkCreate) SetSpeed(i int) *DirectLinkCreate { + dlc.mutation.SetSpeed(i) + return dlc +} + +// SetFile sets the "file" edge to the File entity. +func (dlc *DirectLinkCreate) SetFile(f *File) *DirectLinkCreate { + return dlc.SetFileID(f.ID) +} + +// Mutation returns the DirectLinkMutation object of the builder. +func (dlc *DirectLinkCreate) Mutation() *DirectLinkMutation { + return dlc.mutation +} + +// Save creates the DirectLink in the database. +func (dlc *DirectLinkCreate) Save(ctx context.Context) (*DirectLink, error) { + if err := dlc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, dlc.sqlSave, dlc.mutation, dlc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (dlc *DirectLinkCreate) SaveX(ctx context.Context) *DirectLink { + v, err := dlc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (dlc *DirectLinkCreate) Exec(ctx context.Context) error { + _, err := dlc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dlc *DirectLinkCreate) ExecX(ctx context.Context) { + if err := dlc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (dlc *DirectLinkCreate) defaults() error { + if _, ok := dlc.mutation.CreatedAt(); !ok { + if directlink.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized directlink.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := directlink.DefaultCreatedAt() + dlc.mutation.SetCreatedAt(v) + } + if _, ok := dlc.mutation.UpdatedAt(); !ok { + if directlink.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized directlink.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := directlink.DefaultUpdatedAt() + dlc.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (dlc *DirectLinkCreate) check() error { + if _, ok := dlc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "DirectLink.created_at"`)} + } + if _, ok := dlc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "DirectLink.updated_at"`)} + } + if _, ok := dlc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "DirectLink.name"`)} + } + if _, ok := dlc.mutation.Downloads(); !ok { + return &ValidationError{Name: "downloads", err: errors.New(`ent: missing required field "DirectLink.downloads"`)} + } + if _, ok := dlc.mutation.FileID(); !ok { + return &ValidationError{Name: "file_id", err: errors.New(`ent: missing required field "DirectLink.file_id"`)} + } + if _, ok := dlc.mutation.Speed(); !ok { + return &ValidationError{Name: "speed", err: errors.New(`ent: missing required field "DirectLink.speed"`)} + } + if _, ok := dlc.mutation.FileID(); !ok { + return &ValidationError{Name: "file", err: errors.New(`ent: missing required edge "DirectLink.file"`)} + } + return nil +} + +func (dlc *DirectLinkCreate) sqlSave(ctx context.Context) (*DirectLink, error) { + if err := dlc.check(); err != nil { + return nil, err + } + _node, _spec := dlc.createSpec() + if err := sqlgraph.CreateNode(ctx, dlc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + dlc.mutation.id = &_node.ID + dlc.mutation.done = true + return _node, nil +} + +func (dlc *DirectLinkCreate) createSpec() (*DirectLink, *sqlgraph.CreateSpec) { + var ( + _node = &DirectLink{config: dlc.config} + _spec = sqlgraph.NewCreateSpec(directlink.Table, sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt)) + ) + + if id, ok := dlc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = dlc.conflict + if value, ok := dlc.mutation.CreatedAt(); ok { + _spec.SetField(directlink.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := dlc.mutation.UpdatedAt(); ok { + _spec.SetField(directlink.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := dlc.mutation.DeletedAt(); ok { + _spec.SetField(directlink.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := dlc.mutation.Name(); ok { + _spec.SetField(directlink.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := dlc.mutation.Downloads(); ok { + _spec.SetField(directlink.FieldDownloads, field.TypeInt, value) + _node.Downloads = value + } + if value, ok := dlc.mutation.Speed(); ok { + _spec.SetField(directlink.FieldSpeed, field.TypeInt, value) + _node.Speed = value + } + if nodes := dlc.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: directlink.FileTable, + Columns: []string{directlink.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.FileID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. 
For example: +// +// client.DirectLink.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.DirectLinkUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (dlc *DirectLinkCreate) OnConflict(opts ...sql.ConflictOption) *DirectLinkUpsertOne { + dlc.conflict = opts + return &DirectLinkUpsertOne{ + create: dlc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (dlc *DirectLinkCreate) OnConflictColumns(columns ...string) *DirectLinkUpsertOne { + dlc.conflict = append(dlc.conflict, sql.ConflictColumns(columns...)) + return &DirectLinkUpsertOne{ + create: dlc, + } +} + +type ( + // DirectLinkUpsertOne is the builder for "upsert"-ing + // one DirectLink node. + DirectLinkUpsertOne struct { + create *DirectLinkCreate + } + + // DirectLinkUpsert is the "OnConflict" setter. + DirectLinkUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *DirectLinkUpsert) SetUpdatedAt(v time.Time) *DirectLinkUpsert { + u.Set(directlink.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateUpdatedAt() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DirectLinkUpsert) SetDeletedAt(v time.Time) *DirectLinkUpsert { + u.Set(directlink.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateDeletedAt() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DirectLinkUpsert) ClearDeletedAt() *DirectLinkUpsert { + u.SetNull(directlink.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *DirectLinkUpsert) SetName(v string) *DirectLinkUpsert { + u.Set(directlink.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateName() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldName) + return u +} + +// SetDownloads sets the "downloads" field. +func (u *DirectLinkUpsert) SetDownloads(v int) *DirectLinkUpsert { + u.Set(directlink.FieldDownloads, v) + return u +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateDownloads() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldDownloads) + return u +} + +// AddDownloads adds v to the "downloads" field. +func (u *DirectLinkUpsert) AddDownloads(v int) *DirectLinkUpsert { + u.Add(directlink.FieldDownloads, v) + return u +} + +// SetFileID sets the "file_id" field. +func (u *DirectLinkUpsert) SetFileID(v int) *DirectLinkUpsert { + u.Set(directlink.FieldFileID, v) + return u +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateFileID() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldFileID) + return u +} + +// SetSpeed sets the "speed" field. 
+func (u *DirectLinkUpsert) SetSpeed(v int) *DirectLinkUpsert { + u.Set(directlink.FieldSpeed, v) + return u +} + +// UpdateSpeed sets the "speed" field to the value that was provided on create. +func (u *DirectLinkUpsert) UpdateSpeed() *DirectLinkUpsert { + u.SetExcluded(directlink.FieldSpeed) + return u +} + +// AddSpeed adds v to the "speed" field. +func (u *DirectLinkUpsert) AddSpeed(v int) *DirectLinkUpsert { + u.Add(directlink.FieldSpeed, v) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *DirectLinkUpsertOne) UpdateNewValues() *DirectLinkUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(directlink.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *DirectLinkUpsertOne) Ignore() *DirectLinkUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *DirectLinkUpsertOne) DoNothing() *DirectLinkUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the DirectLinkCreate.OnConflict +// documentation for more info. +func (u *DirectLinkUpsertOne) Update(set func(*DirectLinkUpsert)) *DirectLinkUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&DirectLinkUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *DirectLinkUpsertOne) SetUpdatedAt(v time.Time) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DirectLinkUpsertOne) UpdateUpdatedAt() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DirectLinkUpsertOne) SetDeletedAt(v time.Time) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DirectLinkUpsertOne) UpdateDeletedAt() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DirectLinkUpsertOne) ClearDeletedAt() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *DirectLinkUpsertOne) SetName(v string) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. 
+func (u *DirectLinkUpsertOne) UpdateName() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateName() + }) +} + +// SetDownloads sets the "downloads" field. +func (u *DirectLinkUpsertOne) SetDownloads(v int) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetDownloads(v) + }) +} + +// AddDownloads adds v to the "downloads" field. +func (u *DirectLinkUpsertOne) AddDownloads(v int) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.AddDownloads(v) + }) +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *DirectLinkUpsertOne) UpdateDownloads() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateDownloads() + }) +} + +// SetFileID sets the "file_id" field. +func (u *DirectLinkUpsertOne) SetFileID(v int) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetFileID(v) + }) +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *DirectLinkUpsertOne) UpdateFileID() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateFileID() + }) +} + +// SetSpeed sets the "speed" field. +func (u *DirectLinkUpsertOne) SetSpeed(v int) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.SetSpeed(v) + }) +} + +// AddSpeed adds v to the "speed" field. +func (u *DirectLinkUpsertOne) AddSpeed(v int) *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.AddSpeed(v) + }) +} + +// UpdateSpeed sets the "speed" field to the value that was provided on create. +func (u *DirectLinkUpsertOne) UpdateSpeed() *DirectLinkUpsertOne { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateSpeed() + }) +} + +// Exec executes the query. +func (u *DirectLinkUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for DirectLinkCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *DirectLinkUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *DirectLinkUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *DirectLinkUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *DirectLinkCreate) SetRawID(t int) *DirectLinkCreate { + m.mutation.SetRawID(t) + return m +} + +// DirectLinkCreateBulk is the builder for creating many DirectLink entities in bulk. +type DirectLinkCreateBulk struct { + config + err error + builders []*DirectLinkCreate + conflict []sql.ConflictOption +} + +// Save creates the DirectLink entities in the database. 
+func (dlcb *DirectLinkCreateBulk) Save(ctx context.Context) ([]*DirectLink, error) { + if dlcb.err != nil { + return nil, dlcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(dlcb.builders)) + nodes := make([]*DirectLink, len(dlcb.builders)) + mutators := make([]Mutator, len(dlcb.builders)) + for i := range dlcb.builders { + func(i int, root context.Context) { + builder := dlcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*DirectLinkMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, dlcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = dlcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, dlcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, dlcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (dlcb *DirectLinkCreateBulk) SaveX(ctx context.Context) []*DirectLink { + v, err := dlcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (dlcb *DirectLinkCreateBulk) Exec(ctx context.Context) error { + _, err := dlcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dlcb *DirectLinkCreateBulk) ExecX(ctx context.Context) { + if err := dlcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.DirectLink.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.DirectLinkUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (dlcb *DirectLinkCreateBulk) OnConflict(opts ...sql.ConflictOption) *DirectLinkUpsertBulk { + dlcb.conflict = opts + return &DirectLinkUpsertBulk{ + create: dlcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (dlcb *DirectLinkCreateBulk) OnConflictColumns(columns ...string) *DirectLinkUpsertBulk { + dlcb.conflict = append(dlcb.conflict, sql.ConflictColumns(columns...)) + return &DirectLinkUpsertBulk{ + create: dlcb, + } +} + +// DirectLinkUpsertBulk is the builder for "upsert"-ing +// a bulk of DirectLink nodes. 
+type DirectLinkUpsertBulk struct { + create *DirectLinkCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *DirectLinkUpsertBulk) UpdateNewValues() *DirectLinkUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(directlink.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.DirectLink.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *DirectLinkUpsertBulk) Ignore() *DirectLinkUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *DirectLinkUpsertBulk) DoNothing() *DirectLinkUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the DirectLinkCreateBulk.OnConflict +// documentation for more info. +func (u *DirectLinkUpsertBulk) Update(set func(*DirectLinkUpsert)) *DirectLinkUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&DirectLinkUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *DirectLinkUpsertBulk) SetUpdatedAt(v time.Time) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateUpdatedAt() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *DirectLinkUpsertBulk) SetDeletedAt(v time.Time) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateDeletedAt() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *DirectLinkUpsertBulk) ClearDeletedAt() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *DirectLinkUpsertBulk) SetName(v string) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateName() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateName() + }) +} + +// SetDownloads sets the "downloads" field. +func (u *DirectLinkUpsertBulk) SetDownloads(v int) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetDownloads(v) + }) +} + +// AddDownloads adds v to the "downloads" field. 
+func (u *DirectLinkUpsertBulk) AddDownloads(v int) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.AddDownloads(v) + }) +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateDownloads() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateDownloads() + }) +} + +// SetFileID sets the "file_id" field. +func (u *DirectLinkUpsertBulk) SetFileID(v int) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetFileID(v) + }) +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateFileID() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateFileID() + }) +} + +// SetSpeed sets the "speed" field. +func (u *DirectLinkUpsertBulk) SetSpeed(v int) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.SetSpeed(v) + }) +} + +// AddSpeed adds v to the "speed" field. +func (u *DirectLinkUpsertBulk) AddSpeed(v int) *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.AddSpeed(v) + }) +} + +// UpdateSpeed sets the "speed" field to the value that was provided on create. +func (u *DirectLinkUpsertBulk) UpdateSpeed() *DirectLinkUpsertBulk { + return u.Update(func(s *DirectLinkUpsert) { + s.UpdateSpeed() + }) +} + +// Exec executes the query. +func (u *DirectLinkUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the DirectLinkCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for DirectLinkCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *DirectLinkUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/directlink_delete.go b/ent/directlink_delete.go new file mode 100644 index 00000000..9dca19e9 --- /dev/null +++ b/ent/directlink_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// DirectLinkDelete is the builder for deleting a DirectLink entity. +type DirectLinkDelete struct { + config + hooks []Hook + mutation *DirectLinkMutation +} + +// Where appends a list predicates to the DirectLinkDelete builder. +func (dld *DirectLinkDelete) Where(ps ...predicate.DirectLink) *DirectLinkDelete { + dld.mutation.Where(ps...) + return dld +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (dld *DirectLinkDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, dld.sqlExec, dld.mutation, dld.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (dld *DirectLinkDelete) ExecX(ctx context.Context) int { + n, err := dld.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (dld *DirectLinkDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(directlink.Table, sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt)) + if ps := dld.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, dld.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + dld.mutation.done = true + return affected, err +} + +// DirectLinkDeleteOne is the builder for deleting a single DirectLink entity. +type DirectLinkDeleteOne struct { + dld *DirectLinkDelete +} + +// Where appends a list predicates to the DirectLinkDelete builder. +func (dldo *DirectLinkDeleteOne) Where(ps ...predicate.DirectLink) *DirectLinkDeleteOne { + dldo.dld.mutation.Where(ps...) + return dldo +} + +// Exec executes the deletion query. +func (dldo *DirectLinkDeleteOne) Exec(ctx context.Context) error { + n, err := dldo.dld.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{directlink.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (dldo *DirectLinkDeleteOne) ExecX(ctx context.Context) { + if err := dldo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/directlink_query.go b/ent/directlink_query.go new file mode 100644 index 00000000..21292a9f --- /dev/null +++ b/ent/directlink_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// DirectLinkQuery is the builder for querying DirectLink entities. +type DirectLinkQuery struct { + config + ctx *QueryContext + order []directlink.OrderOption + inters []Interceptor + predicates []predicate.DirectLink + withFile *FileQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the DirectLinkQuery builder. +func (dlq *DirectLinkQuery) Where(ps ...predicate.DirectLink) *DirectLinkQuery { + dlq.predicates = append(dlq.predicates, ps...) + return dlq +} + +// Limit the number of records to be returned by this query. +func (dlq *DirectLinkQuery) Limit(limit int) *DirectLinkQuery { + dlq.ctx.Limit = &limit + return dlq +} + +// Offset to start from. +func (dlq *DirectLinkQuery) Offset(offset int) *DirectLinkQuery { + dlq.ctx.Offset = &offset + return dlq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (dlq *DirectLinkQuery) Unique(unique bool) *DirectLinkQuery { + dlq.ctx.Unique = &unique + return dlq +} + +// Order specifies how the records should be ordered. +func (dlq *DirectLinkQuery) Order(o ...directlink.OrderOption) *DirectLinkQuery { + dlq.order = append(dlq.order, o...) + return dlq +} + +// QueryFile chains the current query on the "file" edge. 
+func (dlq *DirectLinkQuery) QueryFile() *FileQuery { + query := (&FileClient{config: dlq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := dlq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := dlq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(directlink.Table, directlink.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, directlink.FileTable, directlink.FileColumn), + ) + fromU = sqlgraph.SetNeighbors(dlq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first DirectLink entity from the query. +// Returns a *NotFoundError when no DirectLink was found. +func (dlq *DirectLinkQuery) First(ctx context.Context) (*DirectLink, error) { + nodes, err := dlq.Limit(1).All(setContextOp(ctx, dlq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{directlink.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (dlq *DirectLinkQuery) FirstX(ctx context.Context) *DirectLink { + node, err := dlq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first DirectLink ID from the query. +// Returns a *NotFoundError when no DirectLink ID was found. +func (dlq *DirectLinkQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = dlq.Limit(1).IDs(setContextOp(ctx, dlq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{directlink.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (dlq *DirectLinkQuery) FirstIDX(ctx context.Context) int { + id, err := dlq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single DirectLink entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one DirectLink entity is found. +// Returns a *NotFoundError when no DirectLink entities are found. +func (dlq *DirectLinkQuery) Only(ctx context.Context) (*DirectLink, error) { + nodes, err := dlq.Limit(2).All(setContextOp(ctx, dlq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{directlink.Label} + default: + return nil, &NotSingularError{directlink.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (dlq *DirectLinkQuery) OnlyX(ctx context.Context) *DirectLink { + node, err := dlq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only DirectLink ID in the query. +// Returns a *NotSingularError when more than one DirectLink ID is found. +// Returns a *NotFoundError when no entities are found. +func (dlq *DirectLinkQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = dlq.Limit(2).IDs(setContextOp(ctx, dlq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{directlink.Label} + default: + err = &NotSingularError{directlink.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. 
+func (dlq *DirectLinkQuery) OnlyIDX(ctx context.Context) int { + id, err := dlq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of DirectLinks. +func (dlq *DirectLinkQuery) All(ctx context.Context) ([]*DirectLink, error) { + ctx = setContextOp(ctx, dlq.ctx, "All") + if err := dlq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*DirectLink, *DirectLinkQuery]() + return withInterceptors[[]*DirectLink](ctx, dlq, qr, dlq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (dlq *DirectLinkQuery) AllX(ctx context.Context) []*DirectLink { + nodes, err := dlq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of DirectLink IDs. +func (dlq *DirectLinkQuery) IDs(ctx context.Context) (ids []int, err error) { + if dlq.ctx.Unique == nil && dlq.path != nil { + dlq.Unique(true) + } + ctx = setContextOp(ctx, dlq.ctx, "IDs") + if err = dlq.Select(directlink.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (dlq *DirectLinkQuery) IDsX(ctx context.Context) []int { + ids, err := dlq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (dlq *DirectLinkQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dlq.ctx, "Count") + if err := dlq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, dlq, querierCount[*DirectLinkQuery](), dlq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (dlq *DirectLinkQuery) CountX(ctx context.Context) int { + count, err := dlq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (dlq *DirectLinkQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, dlq.ctx, "Exist") + switch _, err := dlq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (dlq *DirectLinkQuery) ExistX(ctx context.Context) bool { + exist, err := dlq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the DirectLinkQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (dlq *DirectLinkQuery) Clone() *DirectLinkQuery { + if dlq == nil { + return nil + } + return &DirectLinkQuery{ + config: dlq.config, + ctx: dlq.ctx.Clone(), + order: append([]directlink.OrderOption{}, dlq.order...), + inters: append([]Interceptor{}, dlq.inters...), + predicates: append([]predicate.DirectLink{}, dlq.predicates...), + withFile: dlq.withFile.Clone(), + // clone intermediate query. + sql: dlq.sql.Clone(), + path: dlq.path, + } +} + +// WithFile tells the query-builder to eager-load the nodes that are connected to +// the "file" edge. The optional arguments are used to configure the query builder of the edge. +func (dlq *DirectLinkQuery) WithFile(opts ...func(*FileQuery)) *DirectLinkQuery { + query := (&FileClient{config: dlq.config}).Query() + for _, opt := range opts { + opt(query) + } + dlq.withFile = query + return dlq +} + +// GroupBy is used to group vertices by one or more fields/columns. 
+// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.DirectLink.Query(). +// GroupBy(directlink.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (dlq *DirectLinkQuery) GroupBy(field string, fields ...string) *DirectLinkGroupBy { + dlq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DirectLinkGroupBy{build: dlq} + grbuild.flds = &dlq.ctx.Fields + grbuild.label = directlink.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.DirectLink.Query(). +// Select(directlink.FieldCreatedAt). +// Scan(ctx, &v) +func (dlq *DirectLinkQuery) Select(fields ...string) *DirectLinkSelect { + dlq.ctx.Fields = append(dlq.ctx.Fields, fields...) + sbuild := &DirectLinkSelect{DirectLinkQuery: dlq} + sbuild.label = directlink.Label + sbuild.flds, sbuild.scan = &dlq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DirectLinkSelect configured with the given aggregations. +func (dlq *DirectLinkQuery) Aggregate(fns ...AggregateFunc) *DirectLinkSelect { + return dlq.Select().Aggregate(fns...) +} + +func (dlq *DirectLinkQuery) prepareQuery(ctx context.Context) error { + for _, inter := range dlq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dlq); err != nil { + return err + } + } + } + for _, f := range dlq.ctx.Fields { + if !directlink.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if dlq.path != nil { + prev, err := dlq.path(ctx) + if err != nil { + return err + } + dlq.sql = prev + } + return nil +} + +func (dlq *DirectLinkQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*DirectLink, error) { + var ( + nodes = []*DirectLink{} + _spec = dlq.querySpec() + loadedTypes = [1]bool{ + dlq.withFile != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*DirectLink).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &DirectLink{config: dlq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, dlq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := dlq.withFile; query != nil { + if err := dlq.loadFile(ctx, query, nodes, nil, + func(n *DirectLink, e *File) { n.Edges.File = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (dlq *DirectLinkQuery) loadFile(ctx context.Context, query *FileQuery, nodes []*DirectLink, init func(*DirectLink), assign func(*DirectLink, *File)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*DirectLink) + for i := range nodes { + fk := nodes[i].FileID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(file.IDIn(ids...)) + neighbors, err := 
query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (dlq *DirectLinkQuery) sqlCount(ctx context.Context) (int, error) { + _spec := dlq.querySpec() + _spec.Node.Columns = dlq.ctx.Fields + if len(dlq.ctx.Fields) > 0 { + _spec.Unique = dlq.ctx.Unique != nil && *dlq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, dlq.driver, _spec) +} + +func (dlq *DirectLinkQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(directlink.Table, directlink.Columns, sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt)) + _spec.From = dlq.sql + if unique := dlq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if dlq.path != nil { + _spec.Unique = true + } + if fields := dlq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, directlink.FieldID) + for i := range fields { + if fields[i] != directlink.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if dlq.withFile != nil { + _spec.Node.AddColumnOnce(directlink.FieldFileID) + } + } + if ps := dlq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := dlq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := dlq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := dlq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (dlq *DirectLinkQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(dlq.driver.Dialect()) + t1 := builder.Table(directlink.Table) + columns := dlq.ctx.Fields + if len(columns) == 0 { + columns = directlink.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if dlq.sql != nil { + selector = dlq.sql + selector.Select(selector.Columns(columns...)...) + } + if dlq.ctx.Unique != nil && *dlq.ctx.Unique { + selector.Distinct() + } + for _, p := range dlq.predicates { + p(selector) + } + for _, p := range dlq.order { + p(selector) + } + if offset := dlq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := dlq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// DirectLinkGroupBy is the group-by builder for DirectLink entities. +type DirectLinkGroupBy struct { + selector + build *DirectLinkQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (dlgb *DirectLinkGroupBy) Aggregate(fns ...AggregateFunc) *DirectLinkGroupBy { + dlgb.fns = append(dlgb.fns, fns...) + return dlgb +} + +// Scan applies the selector query and scans the result into the given value. 
+func (dlgb *DirectLinkGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dlgb.build.ctx, "GroupBy") + if err := dlgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*DirectLinkQuery, *DirectLinkGroupBy](ctx, dlgb.build, dlgb, dlgb.build.inters, v) +} + +func (dlgb *DirectLinkGroupBy) sqlScan(ctx context.Context, root *DirectLinkQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dlgb.fns)) + for _, fn := range dlgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dlgb.flds)+len(dlgb.fns)) + for _, f := range *dlgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*dlgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := dlgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// DirectLinkSelect is the builder for selecting fields of DirectLink entities. +type DirectLinkSelect struct { + *DirectLinkQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (dls *DirectLinkSelect) Aggregate(fns ...AggregateFunc) *DirectLinkSelect { + dls.fns = append(dls.fns, fns...) + return dls +} + +// Scan applies the selector query and scans the result into the given value. +func (dls *DirectLinkSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dls.ctx, "Select") + if err := dls.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*DirectLinkQuery, *DirectLinkSelect](ctx, dls.DirectLinkQuery, dls, dls.inters, v) +} + +func (dls *DirectLinkSelect) sqlScan(ctx context.Context, root *DirectLinkQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(dls.fns)) + for _, fn := range dls.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*dls.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := dls.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/directlink_update.go b/ent/directlink_update.go new file mode 100644 index 00000000..2d80f90a --- /dev/null +++ b/ent/directlink_update.go @@ -0,0 +1,549 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// DirectLinkUpdate is the builder for updating DirectLink entities. +type DirectLinkUpdate struct { + config + hooks []Hook + mutation *DirectLinkMutation +} + +// Where appends a list predicates to the DirectLinkUpdate builder. +func (dlu *DirectLinkUpdate) Where(ps ...predicate.DirectLink) *DirectLinkUpdate { + dlu.mutation.Where(ps...) + return dlu +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (dlu *DirectLinkUpdate) SetUpdatedAt(t time.Time) *DirectLinkUpdate { + dlu.mutation.SetUpdatedAt(t) + return dlu +} + +// SetDeletedAt sets the "deleted_at" field. +func (dlu *DirectLinkUpdate) SetDeletedAt(t time.Time) *DirectLinkUpdate { + dlu.mutation.SetDeletedAt(t) + return dlu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (dlu *DirectLinkUpdate) SetNillableDeletedAt(t *time.Time) *DirectLinkUpdate { + if t != nil { + dlu.SetDeletedAt(*t) + } + return dlu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (dlu *DirectLinkUpdate) ClearDeletedAt() *DirectLinkUpdate { + dlu.mutation.ClearDeletedAt() + return dlu +} + +// SetName sets the "name" field. +func (dlu *DirectLinkUpdate) SetName(s string) *DirectLinkUpdate { + dlu.mutation.SetName(s) + return dlu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (dlu *DirectLinkUpdate) SetNillableName(s *string) *DirectLinkUpdate { + if s != nil { + dlu.SetName(*s) + } + return dlu +} + +// SetDownloads sets the "downloads" field. +func (dlu *DirectLinkUpdate) SetDownloads(i int) *DirectLinkUpdate { + dlu.mutation.ResetDownloads() + dlu.mutation.SetDownloads(i) + return dlu +} + +// SetNillableDownloads sets the "downloads" field if the given value is not nil. +func (dlu *DirectLinkUpdate) SetNillableDownloads(i *int) *DirectLinkUpdate { + if i != nil { + dlu.SetDownloads(*i) + } + return dlu +} + +// AddDownloads adds i to the "downloads" field. +func (dlu *DirectLinkUpdate) AddDownloads(i int) *DirectLinkUpdate { + dlu.mutation.AddDownloads(i) + return dlu +} + +// SetFileID sets the "file_id" field. +func (dlu *DirectLinkUpdate) SetFileID(i int) *DirectLinkUpdate { + dlu.mutation.SetFileID(i) + return dlu +} + +// SetNillableFileID sets the "file_id" field if the given value is not nil. +func (dlu *DirectLinkUpdate) SetNillableFileID(i *int) *DirectLinkUpdate { + if i != nil { + dlu.SetFileID(*i) + } + return dlu +} + +// SetSpeed sets the "speed" field. +func (dlu *DirectLinkUpdate) SetSpeed(i int) *DirectLinkUpdate { + dlu.mutation.ResetSpeed() + dlu.mutation.SetSpeed(i) + return dlu +} + +// SetNillableSpeed sets the "speed" field if the given value is not nil. +func (dlu *DirectLinkUpdate) SetNillableSpeed(i *int) *DirectLinkUpdate { + if i != nil { + dlu.SetSpeed(*i) + } + return dlu +} + +// AddSpeed adds i to the "speed" field. +func (dlu *DirectLinkUpdate) AddSpeed(i int) *DirectLinkUpdate { + dlu.mutation.AddSpeed(i) + return dlu +} + +// SetFile sets the "file" edge to the File entity. +func (dlu *DirectLinkUpdate) SetFile(f *File) *DirectLinkUpdate { + return dlu.SetFileID(f.ID) +} + +// Mutation returns the DirectLinkMutation object of the builder. +func (dlu *DirectLinkUpdate) Mutation() *DirectLinkMutation { + return dlu.mutation +} + +// ClearFile clears the "file" edge to the File entity. +func (dlu *DirectLinkUpdate) ClearFile() *DirectLinkUpdate { + dlu.mutation.ClearFile() + return dlu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (dlu *DirectLinkUpdate) Save(ctx context.Context) (int, error) { + if err := dlu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, dlu.sqlSave, dlu.mutation, dlu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (dlu *DirectLinkUpdate) SaveX(ctx context.Context) int { + affected, err := dlu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (dlu *DirectLinkUpdate) Exec(ctx context.Context) error { + _, err := dlu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dlu *DirectLinkUpdate) ExecX(ctx context.Context) { + if err := dlu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (dlu *DirectLinkUpdate) defaults() error { + if _, ok := dlu.mutation.UpdatedAt(); !ok { + if directlink.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized directlink.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := directlink.UpdateDefaultUpdatedAt() + dlu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (dlu *DirectLinkUpdate) check() error { + if _, ok := dlu.mutation.FileID(); dlu.mutation.FileCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "DirectLink.file"`) + } + return nil +} + +func (dlu *DirectLinkUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := dlu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(directlink.Table, directlink.Columns, sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt)) + if ps := dlu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := dlu.mutation.UpdatedAt(); ok { + _spec.SetField(directlink.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := dlu.mutation.DeletedAt(); ok { + _spec.SetField(directlink.FieldDeletedAt, field.TypeTime, value) + } + if dlu.mutation.DeletedAtCleared() { + _spec.ClearField(directlink.FieldDeletedAt, field.TypeTime) + } + if value, ok := dlu.mutation.Name(); ok { + _spec.SetField(directlink.FieldName, field.TypeString, value) + } + if value, ok := dlu.mutation.Downloads(); ok { + _spec.SetField(directlink.FieldDownloads, field.TypeInt, value) + } + if value, ok := dlu.mutation.AddedDownloads(); ok { + _spec.AddField(directlink.FieldDownloads, field.TypeInt, value) + } + if value, ok := dlu.mutation.Speed(); ok { + _spec.SetField(directlink.FieldSpeed, field.TypeInt, value) + } + if value, ok := dlu.mutation.AddedSpeed(); ok { + _spec.AddField(directlink.FieldSpeed, field.TypeInt, value) + } + if dlu.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: directlink.FileTable, + Columns: []string{directlink.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := dlu.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: directlink.FileTable, + Columns: []string{directlink.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, dlu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{directlink.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + dlu.mutation.done = true + return n, nil +} + +// DirectLinkUpdateOne is the builder for updating a single DirectLink 
entity. +type DirectLinkUpdateOne struct { + config + fields []string + hooks []Hook + mutation *DirectLinkMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (dluo *DirectLinkUpdateOne) SetUpdatedAt(t time.Time) *DirectLinkUpdateOne { + dluo.mutation.SetUpdatedAt(t) + return dluo +} + +// SetDeletedAt sets the "deleted_at" field. +func (dluo *DirectLinkUpdateOne) SetDeletedAt(t time.Time) *DirectLinkUpdateOne { + dluo.mutation.SetDeletedAt(t) + return dluo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (dluo *DirectLinkUpdateOne) SetNillableDeletedAt(t *time.Time) *DirectLinkUpdateOne { + if t != nil { + dluo.SetDeletedAt(*t) + } + return dluo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (dluo *DirectLinkUpdateOne) ClearDeletedAt() *DirectLinkUpdateOne { + dluo.mutation.ClearDeletedAt() + return dluo +} + +// SetName sets the "name" field. +func (dluo *DirectLinkUpdateOne) SetName(s string) *DirectLinkUpdateOne { + dluo.mutation.SetName(s) + return dluo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (dluo *DirectLinkUpdateOne) SetNillableName(s *string) *DirectLinkUpdateOne { + if s != nil { + dluo.SetName(*s) + } + return dluo +} + +// SetDownloads sets the "downloads" field. +func (dluo *DirectLinkUpdateOne) SetDownloads(i int) *DirectLinkUpdateOne { + dluo.mutation.ResetDownloads() + dluo.mutation.SetDownloads(i) + return dluo +} + +// SetNillableDownloads sets the "downloads" field if the given value is not nil. +func (dluo *DirectLinkUpdateOne) SetNillableDownloads(i *int) *DirectLinkUpdateOne { + if i != nil { + dluo.SetDownloads(*i) + } + return dluo +} + +// AddDownloads adds i to the "downloads" field. +func (dluo *DirectLinkUpdateOne) AddDownloads(i int) *DirectLinkUpdateOne { + dluo.mutation.AddDownloads(i) + return dluo +} + +// SetFileID sets the "file_id" field. +func (dluo *DirectLinkUpdateOne) SetFileID(i int) *DirectLinkUpdateOne { + dluo.mutation.SetFileID(i) + return dluo +} + +// SetNillableFileID sets the "file_id" field if the given value is not nil. +func (dluo *DirectLinkUpdateOne) SetNillableFileID(i *int) *DirectLinkUpdateOne { + if i != nil { + dluo.SetFileID(*i) + } + return dluo +} + +// SetSpeed sets the "speed" field. +func (dluo *DirectLinkUpdateOne) SetSpeed(i int) *DirectLinkUpdateOne { + dluo.mutation.ResetSpeed() + dluo.mutation.SetSpeed(i) + return dluo +} + +// SetNillableSpeed sets the "speed" field if the given value is not nil. +func (dluo *DirectLinkUpdateOne) SetNillableSpeed(i *int) *DirectLinkUpdateOne { + if i != nil { + dluo.SetSpeed(*i) + } + return dluo +} + +// AddSpeed adds i to the "speed" field. +func (dluo *DirectLinkUpdateOne) AddSpeed(i int) *DirectLinkUpdateOne { + dluo.mutation.AddSpeed(i) + return dluo +} + +// SetFile sets the "file" edge to the File entity. +func (dluo *DirectLinkUpdateOne) SetFile(f *File) *DirectLinkUpdateOne { + return dluo.SetFileID(f.ID) +} + +// Mutation returns the DirectLinkMutation object of the builder. +func (dluo *DirectLinkUpdateOne) Mutation() *DirectLinkMutation { + return dluo.mutation +} + +// ClearFile clears the "file" edge to the File entity. +func (dluo *DirectLinkUpdateOne) ClearFile() *DirectLinkUpdateOne { + dluo.mutation.ClearFile() + return dluo +} + +// Where appends a list predicates to the DirectLinkUpdate builder. +func (dluo *DirectLinkUpdateOne) Where(ps ...predicate.DirectLink) *DirectLinkUpdateOne { + dluo.mutation.Where(ps...) 
+ return dluo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (dluo *DirectLinkUpdateOne) Select(field string, fields ...string) *DirectLinkUpdateOne { + dluo.fields = append([]string{field}, fields...) + return dluo +} + +// Save executes the query and returns the updated DirectLink entity. +func (dluo *DirectLinkUpdateOne) Save(ctx context.Context) (*DirectLink, error) { + if err := dluo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, dluo.sqlSave, dluo.mutation, dluo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (dluo *DirectLinkUpdateOne) SaveX(ctx context.Context) *DirectLink { + node, err := dluo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (dluo *DirectLinkUpdateOne) Exec(ctx context.Context) error { + _, err := dluo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (dluo *DirectLinkUpdateOne) ExecX(ctx context.Context) { + if err := dluo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (dluo *DirectLinkUpdateOne) defaults() error { + if _, ok := dluo.mutation.UpdatedAt(); !ok { + if directlink.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized directlink.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := directlink.UpdateDefaultUpdatedAt() + dluo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (dluo *DirectLinkUpdateOne) check() error { + if _, ok := dluo.mutation.FileID(); dluo.mutation.FileCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "DirectLink.file"`) + } + return nil +} + +func (dluo *DirectLinkUpdateOne) sqlSave(ctx context.Context) (_node *DirectLink, err error) { + if err := dluo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(directlink.Table, directlink.Columns, sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt)) + id, ok := dluo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "DirectLink.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := dluo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, directlink.FieldID) + for _, f := range fields { + if !directlink.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != directlink.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := dluo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := dluo.mutation.UpdatedAt(); ok { + _spec.SetField(directlink.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := dluo.mutation.DeletedAt(); ok { + _spec.SetField(directlink.FieldDeletedAt, field.TypeTime, value) + } + if dluo.mutation.DeletedAtCleared() { + _spec.ClearField(directlink.FieldDeletedAt, field.TypeTime) + } + if value, ok := dluo.mutation.Name(); ok { + _spec.SetField(directlink.FieldName, field.TypeString, value) + } + if value, ok := dluo.mutation.Downloads(); ok { + _spec.SetField(directlink.FieldDownloads, field.TypeInt, value) + } + if value, ok := 
dluo.mutation.AddedDownloads(); ok { + _spec.AddField(directlink.FieldDownloads, field.TypeInt, value) + } + if value, ok := dluo.mutation.Speed(); ok { + _spec.SetField(directlink.FieldSpeed, field.TypeInt, value) + } + if value, ok := dluo.mutation.AddedSpeed(); ok { + _spec.AddField(directlink.FieldSpeed, field.TypeInt, value) + } + if dluo.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: directlink.FileTable, + Columns: []string{directlink.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := dluo.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: directlink.FileTable, + Columns: []string{directlink.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &DirectLink{config: dluo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, dluo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{directlink.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + dluo.mutation.done = true + return _node, nil +} diff --git a/ent/ent.go b/ent/ent.go new file mode 100644 index 00000000..4026fe67 --- /dev/null +++ b/ent/ent.go @@ -0,0 +1,632 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// ent aliases to avoid import conflicts in user's code. +type ( + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc +) + +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. 
+func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + +// OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. +type OrderFunc func(*sql.Selector) + +var ( + initCheck sync.Once + columnCheck sql.ColumnCheck +) + +// columnChecker checks if the column exists in the given table. +func checkColumn(table, column string) error { + initCheck.Do(func() { + columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + davaccount.Table: davaccount.ValidColumn, + directlink.Table: directlink.ValidColumn, + entity.Table: entity.ValidColumn, + file.Table: file.ValidColumn, + group.Table: group.ValidColumn, + metadata.Table: metadata.ValidColumn, + node.Table: node.ValidColumn, + passkey.Table: passkey.ValidColumn, + setting.Table: setting.ValidColumn, + share.Table: share.ValidColumn, + storagepolicy.Table: storagepolicy.ValidColumn, + task.Table: task.ValidColumn, + user.Table: user.ValidColumn, + }) + }) + return columnCheck(table, column) +} + +// Asc applies the given fields in ASC order. +func Asc(fields ...string) func(*sql.Selector) { + return func(s *sql.Selector) { + for _, f := range fields { + if err := checkColumn(s.TableName(), f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + } + s.OrderBy(sql.Asc(s.C(f))) + } + } +} + +// Desc applies the given fields in DESC order. +func Desc(fields ...string) func(*sql.Selector) { + return func(s *sql.Selector) { + for _, f := range fields { + if err := checkColumn(s.TableName(), f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + } + s.OrderBy(sql.Desc(s.C(f))) + } + } +} + +// AggregateFunc applies an aggregation step on the group-by traversal/selector. +type AggregateFunc func(*sql.Selector) string + +// As is a pseudo aggregation function for renaming another other functions with custom names. For example: +// +// GroupBy(field1, field2). +// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). +// Scan(ctx, &v) +func As(fn AggregateFunc, end string) AggregateFunc { + return func(s *sql.Selector) string { + return sql.As(fn(s), end) + } +} + +// Count applies the "count" aggregation function on each group. +func Count() AggregateFunc { + return func(s *sql.Selector) string { + return sql.Count("*") + } +} + +// Max applies the "max" aggregation function on the given field of each group. +func Max(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Max(s.C(field)) + } +} + +// Mean applies the "mean" aggregation function on the given field of each group. 
+func Mean(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Avg(s.C(field)) + } +} + +// Min applies the "min" aggregation function on the given field of each group. +func Min(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Min(s.C(field)) + } +} + +// Sum applies the "sum" aggregation function on the given field of each group. +func Sum(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Sum(s.C(field)) + } +} + +// ValidationError returns when validating a field or edge fails. +type ValidationError struct { + Name string // Field or edge name. + err error +} + +// Error implements the error interface. +func (e *ValidationError) Error() string { + return e.err.Error() +} + +// Unwrap implements the errors.Wrapper interface. +func (e *ValidationError) Unwrap() error { + return e.err +} + +// IsValidationError returns a boolean indicating whether the error is a validation error. +func IsValidationError(err error) bool { + if err == nil { + return false + } + var e *ValidationError + return errors.As(err, &e) +} + +// NotFoundError returns when trying to fetch a specific entity and it was not found in the database. +type NotFoundError struct { + label string +} + +// Error implements the error interface. +func (e *NotFoundError) Error() string { + return "ent: " + e.label + " not found" +} + +// IsNotFound returns a boolean indicating whether the error is a not found error. +func IsNotFound(err error) bool { + if err == nil { + return false + } + var e *NotFoundError + return errors.As(err, &e) +} + +// MaskNotFound masks not found error. +func MaskNotFound(err error) error { + if IsNotFound(err) { + return nil + } + return err +} + +// NotSingularError returns when trying to fetch a singular entity and more then one was found in the database. +type NotSingularError struct { + label string +} + +// Error implements the error interface. +func (e *NotSingularError) Error() string { + return "ent: " + e.label + " not singular" +} + +// IsNotSingular returns a boolean indicating whether the error is a not singular error. +func IsNotSingular(err error) bool { + if err == nil { + return false + } + var e *NotSingularError + return errors.As(err, &e) +} + +// NotLoadedError returns when trying to get a node that was not loaded by the query. +type NotLoadedError struct { + edge string +} + +// Error implements the error interface. +func (e *NotLoadedError) Error() string { + return "ent: " + e.edge + " edge was not loaded" +} + +// IsNotLoaded returns a boolean indicating whether the error is a not loaded error. +func IsNotLoaded(err error) bool { + if err == nil { + return false + } + var e *NotLoadedError + return errors.As(err, &e) +} + +// ConstraintError returns when trying to create/update one or more entities and +// one or more of their constraints failed. For example, violation of edge or +// field uniqueness. +type ConstraintError struct { + msg string + wrap error +} + +// Error implements the error interface. 
+func (e ConstraintError) Error() string { + return "ent: constraint failed: " + e.msg +} + +// Unwrap implements the errors.Wrapper interface. +func (e *ConstraintError) Unwrap() error { + return e.wrap +} + +// IsConstraintError returns a boolean indicating whether the error is a constraint failure. +func IsConstraintError(err error) bool { + if err == nil { + return false + } + var e *ConstraintError + return errors.As(err, &e) +} + +// selector embedded by the different Select/GroupBy builders. +type selector struct { + label string + flds *[]string + fns []AggregateFunc + scan func(context.Context, any) error +} + +// ScanX is like Scan, but panics if an error occurs. +func (s *selector) ScanX(ctx context.Context, v any) { + if err := s.scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from a selector. It is only allowed when selecting one field. +func (s *selector) Strings(ctx context.Context) ([]string, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Strings is not achievable when selecting more than 1 field") + } + var v []string + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (s *selector) StringsX(ctx context.Context) []string { + v, err := s.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// String returns a single string from a selector. It is only allowed when selecting one field. +func (s *selector) String(ctx context.Context) (_ string, err error) { + var v []string + if v, err = s.Strings(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Strings returned %d results when one was expected", len(v)) + } + return +} + +// StringX is like String, but panics if an error occurs. +func (s *selector) StringX(ctx context.Context) string { + v, err := s.String(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from a selector. It is only allowed when selecting one field. +func (s *selector) Ints(ctx context.Context) ([]int, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Ints is not achievable when selecting more than 1 field") + } + var v []int + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (s *selector) IntsX(ctx context.Context) []int { + v, err := s.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Int returns a single int from a selector. It is only allowed when selecting one field. +func (s *selector) Int(ctx context.Context) (_ int, err error) { + var v []int + if v, err = s.Ints(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Ints returned %d results when one was expected", len(v)) + } + return +} + +// IntX is like Int, but panics if an error occurs. +func (s *selector) IntX(ctx context.Context) int { + v, err := s.Int(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from a selector. It is only allowed when selecting one field. 
+func (s *selector) Float64s(ctx context.Context) ([]float64, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Float64s is not achievable when selecting more than 1 field") + } + var v []float64 + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (s *selector) Float64sX(ctx context.Context) []float64 { + v, err := s.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64 returns a single float64 from a selector. It is only allowed when selecting one field. +func (s *selector) Float64(ctx context.Context) (_ float64, err error) { + var v []float64 + if v, err = s.Float64s(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Float64s returned %d results when one was expected", len(v)) + } + return +} + +// Float64X is like Float64, but panics if an error occurs. +func (s *selector) Float64X(ctx context.Context) float64 { + v, err := s.Float64(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from a selector. It is only allowed when selecting one field. +func (s *selector) Bools(ctx context.Context) ([]bool, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Bools is not achievable when selecting more than 1 field") + } + var v []bool + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (s *selector) BoolsX(ctx context.Context) []bool { + v, err := s.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bool returns a single bool from a selector. It is only allowed when selecting one field. +func (s *selector) Bool(ctx context.Context) (_ bool, err error) { + var v []bool + if v, err = s.Bools(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Bools returned %d results when one was expected", len(v)) + } + return +} + +// BoolX is like Bool, but panics if an error occurs. +func (s *selector) BoolX(ctx context.Context) bool { + v, err := s.Bool(ctx) + if err != nil { + panic(err) + } + return v +} + +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. 
+func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/ent/entc.go b/ent/entc.go new file mode 100644 index 00000000..4f2ab1f0 --- /dev/null +++ b/ent/entc.go @@ -0,0 +1,29 @@ +//go:build ignore + +package main + +import ( + "log" + + "entgo.io/ent/entc" + "entgo.io/ent/entc/gen" +) + +func main() { + if err := entc.Generate("./schema", &gen.Config{ + Features: []gen.Feature{ + gen.FeatureIntercept, + gen.FeatureSnapshot, + gen.FeatureUpsert, + gen.FeatureUpsert, + gen.FeatureExecQuery, + }, + Templates: []*gen.Template{ + gen.MustParse(gen.NewTemplate("edge_helper").ParseFiles("templates/edgehelper.tmpl")), + gen.MustParse(gen.NewTemplate("mutation_helper").ParseFiles("templates/mutationhelper.tmpl")), + gen.MustParse(gen.NewTemplate("create_helper").ParseFiles("templates/createhelper.tmpl")), + }, + }); err != nil { + log.Fatal("running ent codegen:", err) + } +} diff --git a/ent/entity.go b/ent/entity.go new file mode 100644 index 00000000..6e57c30b --- /dev/null +++ b/ent/entity.go @@ -0,0 +1,317 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// Entity is the model entity for the Entity schema. +type Entity struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Type holds the value of the "type" field. + Type int `json:"type,omitempty"` + // Source holds the value of the "source" field. + Source string `json:"source,omitempty"` + // Size holds the value of the "size" field. + Size int64 `json:"size,omitempty"` + // ReferenceCount holds the value of the "reference_count" field. + ReferenceCount int `json:"reference_count,omitempty"` + // StoragePolicyEntities holds the value of the "storage_policy_entities" field. + StoragePolicyEntities int `json:"storage_policy_entities,omitempty"` + // CreatedBy holds the value of the "created_by" field. + CreatedBy int `json:"created_by,omitempty"` + // UploadSessionID holds the value of the "upload_session_id" field. + UploadSessionID *uuid.UUID `json:"upload_session_id,omitempty"` + // RecycleOptions holds the value of the "recycle_options" field. + RecycleOptions *types.EntityRecycleOption `json:"recycle_options,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the EntityQuery when eager-loading is set. + Edges EntityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// EntityEdges holds the relations/edges for other nodes in the graph. +type EntityEdges struct { + // File holds the value of the file edge. + File []*File `json:"file,omitempty"` + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // StoragePolicy holds the value of the storage_policy edge. + StoragePolicy *StoragePolicy `json:"storage_policy,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// FileOrErr returns the File value or an error if the edge +// was not loaded in eager-loading. +func (e EntityEdges) FileOrErr() ([]*File, error) { + if e.loadedTypes[0] { + return e.File, nil + } + return nil, &NotLoadedError{edge: "file"} +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e EntityEdges) UserOrErr() (*User, error) { + if e.loadedTypes[1] { + if e.User == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.User, nil + } + return nil, &NotLoadedError{edge: "user"} +} + +// StoragePolicyOrErr returns the StoragePolicy value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e EntityEdges) StoragePolicyOrErr() (*StoragePolicy, error) { + if e.loadedTypes[2] { + if e.StoragePolicy == nil { + // Edge was loaded but was not found. 
+ return nil, &NotFoundError{label: storagepolicy.Label} + } + return e.StoragePolicy, nil + } + return nil, &NotLoadedError{edge: "storage_policy"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Entity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case entity.FieldUploadSessionID: + values[i] = &sql.NullScanner{S: new(uuid.UUID)} + case entity.FieldRecycleOptions: + values[i] = new([]byte) + case entity.FieldID, entity.FieldType, entity.FieldSize, entity.FieldReferenceCount, entity.FieldStoragePolicyEntities, entity.FieldCreatedBy: + values[i] = new(sql.NullInt64) + case entity.FieldSource: + values[i] = new(sql.NullString) + case entity.FieldCreatedAt, entity.FieldUpdatedAt, entity.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Entity fields. +func (e *Entity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case entity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + e.ID = int(value.Int64) + case entity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + e.CreatedAt = value.Time + } + case entity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + e.UpdatedAt = value.Time + } + case entity.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + e.DeletedAt = new(time.Time) + *e.DeletedAt = value.Time + } + case entity.FieldType: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + e.Type = int(value.Int64) + } + case entity.FieldSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field source", values[i]) + } else if value.Valid { + e.Source = value.String + } + case entity.FieldSize: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field size", values[i]) + } else if value.Valid { + e.Size = value.Int64 + } + case entity.FieldReferenceCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field reference_count", values[i]) + } else if value.Valid { + e.ReferenceCount = int(value.Int64) + } + case entity.FieldStoragePolicyEntities: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field storage_policy_entities", values[i]) + } else if value.Valid { + e.StoragePolicyEntities = int(value.Int64) + } + case entity.FieldCreatedBy: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field created_by", values[i]) + } else if value.Valid { + e.CreatedBy = int(value.Int64) + } + case entity.FieldUploadSessionID: + if value, ok := values[i].(*sql.NullScanner); !ok { + return 
fmt.Errorf("unexpected type %T for field upload_session_id", values[i]) + } else if value.Valid { + e.UploadSessionID = new(uuid.UUID) + *e.UploadSessionID = *value.S.(*uuid.UUID) + } + case entity.FieldRecycleOptions: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field recycle_options", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &e.RecycleOptions); err != nil { + return fmt.Errorf("unmarshal field recycle_options: %w", err) + } + } + default: + e.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Entity. +// This includes values selected through modifiers, order, etc. +func (e *Entity) Value(name string) (ent.Value, error) { + return e.selectValues.Get(name) +} + +// QueryFile queries the "file" edge of the Entity entity. +func (e *Entity) QueryFile() *FileQuery { + return NewEntityClient(e.config).QueryFile(e) +} + +// QueryUser queries the "user" edge of the Entity entity. +func (e *Entity) QueryUser() *UserQuery { + return NewEntityClient(e.config).QueryUser(e) +} + +// QueryStoragePolicy queries the "storage_policy" edge of the Entity entity. +func (e *Entity) QueryStoragePolicy() *StoragePolicyQuery { + return NewEntityClient(e.config).QueryStoragePolicy(e) +} + +// Update returns a builder for updating this Entity. +// Note that you need to call Entity.Unwrap() before calling this method if this Entity +// was returned from a transaction, and the transaction was committed or rolled back. +func (e *Entity) Update() *EntityUpdateOne { + return NewEntityClient(e.config).UpdateOne(e) +} + +// Unwrap unwraps the Entity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (e *Entity) Unwrap() *Entity { + _tx, ok := e.config.driver.(*txDriver) + if !ok { + panic("ent: Entity is not a transactional entity") + } + e.config.driver = _tx.drv + return e +} + +// String implements the fmt.Stringer. 
+func (e *Entity) String() string { + var builder strings.Builder + builder.WriteString("Entity(") + builder.WriteString(fmt.Sprintf("id=%v, ", e.ID)) + builder.WriteString("created_at=") + builder.WriteString(e.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(e.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := e.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("type=") + builder.WriteString(fmt.Sprintf("%v", e.Type)) + builder.WriteString(", ") + builder.WriteString("source=") + builder.WriteString(e.Source) + builder.WriteString(", ") + builder.WriteString("size=") + builder.WriteString(fmt.Sprintf("%v", e.Size)) + builder.WriteString(", ") + builder.WriteString("reference_count=") + builder.WriteString(fmt.Sprintf("%v", e.ReferenceCount)) + builder.WriteString(", ") + builder.WriteString("storage_policy_entities=") + builder.WriteString(fmt.Sprintf("%v", e.StoragePolicyEntities)) + builder.WriteString(", ") + builder.WriteString("created_by=") + builder.WriteString(fmt.Sprintf("%v", e.CreatedBy)) + builder.WriteString(", ") + if v := e.UploadSessionID; v != nil { + builder.WriteString("upload_session_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("recycle_options=") + builder.WriteString(fmt.Sprintf("%v", e.RecycleOptions)) + builder.WriteByte(')') + return builder.String() +} + +// SetFile manually set the edge as loaded state. +func (e *Entity) SetFile(v []*File) { + e.Edges.File = v + e.Edges.loadedTypes[0] = true +} + +// SetUser manually set the edge as loaded state. +func (e *Entity) SetUser(v *User) { + e.Edges.User = v + e.Edges.loadedTypes[1] = true +} + +// SetStoragePolicy manually set the edge as loaded state. +func (e *Entity) SetStoragePolicy(v *StoragePolicy) { + e.Edges.StoragePolicy = v + e.Edges.loadedTypes[2] = true +} + +// Entities is a parsable slice of Entity. +type Entities []*Entity diff --git a/ent/entity/entity.go b/ent/entity/entity.go new file mode 100644 index 00000000..ed8e402d --- /dev/null +++ b/ent/entity/entity.go @@ -0,0 +1,224 @@ +// Code generated by ent, DO NOT EDIT. + +package entity + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the entity type in the database. + Label = "entity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldSource holds the string denoting the source field in the database. + FieldSource = "source" + // FieldSize holds the string denoting the size field in the database. + FieldSize = "size" + // FieldReferenceCount holds the string denoting the reference_count field in the database. + FieldReferenceCount = "reference_count" + // FieldStoragePolicyEntities holds the string denoting the storage_policy_entities field in the database. 
+ FieldStoragePolicyEntities = "storage_policy_entities" + // FieldCreatedBy holds the string denoting the created_by field in the database. + FieldCreatedBy = "created_by" + // FieldUploadSessionID holds the string denoting the upload_session_id field in the database. + FieldUploadSessionID = "upload_session_id" + // FieldRecycleOptions holds the string denoting the recycle_options field in the database. + FieldRecycleOptions = "recycle_options" + // EdgeFile holds the string denoting the file edge name in mutations. + EdgeFile = "file" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeStoragePolicy holds the string denoting the storage_policy edge name in mutations. + EdgeStoragePolicy = "storage_policy" + // Table holds the table name of the entity in the database. + Table = "entities" + // FileTable is the table that holds the file relation/edge. The primary key declared below. + FileTable = "file_entities" + // FileInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. + FileInverseTable = "files" + // UserTable is the table that holds the user relation/edge. + UserTable = "entities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "created_by" + // StoragePolicyTable is the table that holds the storage_policy relation/edge. + StoragePolicyTable = "entities" + // StoragePolicyInverseTable is the table name for the StoragePolicy entity. + // It exists in this package in order to avoid circular dependency with the "storagepolicy" package. + StoragePolicyInverseTable = "storage_policies" + // StoragePolicyColumn is the table column denoting the storage_policy relation/edge. + StoragePolicyColumn = "storage_policy_entities" +) + +// Columns holds all SQL columns for entity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldType, + FieldSource, + FieldSize, + FieldReferenceCount, + FieldStoragePolicyEntities, + FieldCreatedBy, + FieldUploadSessionID, + FieldRecycleOptions, +} + +var ( + // FilePrimaryKey and FileColumn2 are the table columns denoting the + // primary key for the file relation (M2M). + FilePrimaryKey = []string{"file_id", "entity_id"} +) + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultReferenceCount holds the default value on creation for the "reference_count" field. 
+ DefaultReferenceCount int +) + +// OrderOption defines the ordering options for the Entity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// BySource orders the results by the source field. +func BySource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSource, opts...).ToFunc() +} + +// BySize orders the results by the size field. +func BySize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSize, opts...).ToFunc() +} + +// ByReferenceCount orders the results by the reference_count field. +func ByReferenceCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldReferenceCount, opts...).ToFunc() +} + +// ByStoragePolicyEntities orders the results by the storage_policy_entities field. +func ByStoragePolicyEntities(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoragePolicyEntities, opts...).ToFunc() +} + +// ByCreatedBy orders the results by the created_by field. +func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedBy, opts...).ToFunc() +} + +// ByUploadSessionID orders the results by the upload_session_id field. +func ByUploadSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUploadSessionID, opts...).ToFunc() +} + +// ByFileCount orders the results by file count. +func ByFileCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newFileStep(), opts...) + } +} + +// ByFile orders the results by file terms. +func ByFile(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFileStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByStoragePolicyField orders the results by storage_policy field. 
+func ByStoragePolicyField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStoragePolicyStep(), sql.OrderByField(field, opts...)) + } +} +func newFileStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FileInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, FileTable, FilePrimaryKey...), + ) +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newStoragePolicyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StoragePolicyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePolicyTable, StoragePolicyColumn), + ) +} diff --git a/ent/entity/where.go b/ent/entity/where.go new file mode 100644 index 00000000..90fbec37 --- /dev/null +++ b/ent/entity/where.go @@ -0,0 +1,616 @@ +// Code generated by ent, DO NOT EDIT. + +package entity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/gofrs/uuid" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Type applies equality check predicate on the "type" field. It's identical to TypeEQ. +func Type(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldType, v)) +} + +// Source applies equality check predicate on the "source" field. It's identical to SourceEQ. 
+func Source(v string) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldSource, v)) +} + +// Size applies equality check predicate on the "size" field. It's identical to SizeEQ. +func Size(v int64) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldSize, v)) +} + +// ReferenceCount applies equality check predicate on the "reference_count" field. It's identical to ReferenceCountEQ. +func ReferenceCount(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldReferenceCount, v)) +} + +// StoragePolicyEntities applies equality check predicate on the "storage_policy_entities" field. It's identical to StoragePolicyEntitiesEQ. +func StoragePolicyEntities(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldStoragePolicyEntities, v)) +} + +// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ. +func CreatedBy(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldCreatedBy, v)) +} + +// UploadSessionID applies equality check predicate on the "upload_session_id" field. It's identical to UploadSessionIDEQ. +func UploadSessionID(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldUploadSessionID, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. 
+func UpdatedAtNotIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Entity { + return predicate.Entity(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Entity { + return predicate.Entity(sql.FieldNotNull(FieldDeletedAt)) +} + +// TypeEQ applies the EQ predicate on the "type" field. +func TypeEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldType, vs...)) +} + +// TypeGT applies the GT predicate on the "type" field. 
+func TypeGT(v int) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldType, v)) +} + +// TypeGTE applies the GTE predicate on the "type" field. +func TypeGTE(v int) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldType, v)) +} + +// TypeLT applies the LT predicate on the "type" field. +func TypeLT(v int) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldType, v)) +} + +// TypeLTE applies the LTE predicate on the "type" field. +func TypeLTE(v int) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldType, v)) +} + +// SourceEQ applies the EQ predicate on the "source" field. +func SourceEQ(v string) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldSource, v)) +} + +// SourceNEQ applies the NEQ predicate on the "source" field. +func SourceNEQ(v string) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldSource, v)) +} + +// SourceIn applies the In predicate on the "source" field. +func SourceIn(vs ...string) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldSource, vs...)) +} + +// SourceNotIn applies the NotIn predicate on the "source" field. +func SourceNotIn(vs ...string) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldSource, vs...)) +} + +// SourceGT applies the GT predicate on the "source" field. +func SourceGT(v string) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldSource, v)) +} + +// SourceGTE applies the GTE predicate on the "source" field. +func SourceGTE(v string) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldSource, v)) +} + +// SourceLT applies the LT predicate on the "source" field. +func SourceLT(v string) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldSource, v)) +} + +// SourceLTE applies the LTE predicate on the "source" field. +func SourceLTE(v string) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldSource, v)) +} + +// SourceContains applies the Contains predicate on the "source" field. +func SourceContains(v string) predicate.Entity { + return predicate.Entity(sql.FieldContains(FieldSource, v)) +} + +// SourceHasPrefix applies the HasPrefix predicate on the "source" field. +func SourceHasPrefix(v string) predicate.Entity { + return predicate.Entity(sql.FieldHasPrefix(FieldSource, v)) +} + +// SourceHasSuffix applies the HasSuffix predicate on the "source" field. +func SourceHasSuffix(v string) predicate.Entity { + return predicate.Entity(sql.FieldHasSuffix(FieldSource, v)) +} + +// SourceEqualFold applies the EqualFold predicate on the "source" field. +func SourceEqualFold(v string) predicate.Entity { + return predicate.Entity(sql.FieldEqualFold(FieldSource, v)) +} + +// SourceContainsFold applies the ContainsFold predicate on the "source" field. +func SourceContainsFold(v string) predicate.Entity { + return predicate.Entity(sql.FieldContainsFold(FieldSource, v)) +} + +// SizeEQ applies the EQ predicate on the "size" field. +func SizeEQ(v int64) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldSize, v)) +} + +// SizeNEQ applies the NEQ predicate on the "size" field. +func SizeNEQ(v int64) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldSize, v)) +} + +// SizeIn applies the In predicate on the "size" field. +func SizeIn(vs ...int64) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldSize, vs...)) +} + +// SizeNotIn applies the NotIn predicate on the "size" field. 
+func SizeNotIn(vs ...int64) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldSize, vs...)) +} + +// SizeGT applies the GT predicate on the "size" field. +func SizeGT(v int64) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldSize, v)) +} + +// SizeGTE applies the GTE predicate on the "size" field. +func SizeGTE(v int64) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldSize, v)) +} + +// SizeLT applies the LT predicate on the "size" field. +func SizeLT(v int64) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldSize, v)) +} + +// SizeLTE applies the LTE predicate on the "size" field. +func SizeLTE(v int64) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldSize, v)) +} + +// ReferenceCountEQ applies the EQ predicate on the "reference_count" field. +func ReferenceCountEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldReferenceCount, v)) +} + +// ReferenceCountNEQ applies the NEQ predicate on the "reference_count" field. +func ReferenceCountNEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldReferenceCount, v)) +} + +// ReferenceCountIn applies the In predicate on the "reference_count" field. +func ReferenceCountIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldReferenceCount, vs...)) +} + +// ReferenceCountNotIn applies the NotIn predicate on the "reference_count" field. +func ReferenceCountNotIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldReferenceCount, vs...)) +} + +// ReferenceCountGT applies the GT predicate on the "reference_count" field. +func ReferenceCountGT(v int) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldReferenceCount, v)) +} + +// ReferenceCountGTE applies the GTE predicate on the "reference_count" field. +func ReferenceCountGTE(v int) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldReferenceCount, v)) +} + +// ReferenceCountLT applies the LT predicate on the "reference_count" field. +func ReferenceCountLT(v int) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldReferenceCount, v)) +} + +// ReferenceCountLTE applies the LTE predicate on the "reference_count" field. +func ReferenceCountLTE(v int) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldReferenceCount, v)) +} + +// StoragePolicyEntitiesEQ applies the EQ predicate on the "storage_policy_entities" field. +func StoragePolicyEntitiesEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldStoragePolicyEntities, v)) +} + +// StoragePolicyEntitiesNEQ applies the NEQ predicate on the "storage_policy_entities" field. +func StoragePolicyEntitiesNEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldStoragePolicyEntities, v)) +} + +// StoragePolicyEntitiesIn applies the In predicate on the "storage_policy_entities" field. +func StoragePolicyEntitiesIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldStoragePolicyEntities, vs...)) +} + +// StoragePolicyEntitiesNotIn applies the NotIn predicate on the "storage_policy_entities" field. +func StoragePolicyEntitiesNotIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldStoragePolicyEntities, vs...)) +} + +// CreatedByEQ applies the EQ predicate on the "created_by" field. +func CreatedByEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldCreatedBy, v)) +} + +// CreatedByNEQ applies the NEQ predicate on the "created_by" field. 
+func CreatedByNEQ(v int) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldCreatedBy, v)) +} + +// CreatedByIn applies the In predicate on the "created_by" field. +func CreatedByIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldCreatedBy, vs...)) +} + +// CreatedByNotIn applies the NotIn predicate on the "created_by" field. +func CreatedByNotIn(vs ...int) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldCreatedBy, vs...)) +} + +// CreatedByIsNil applies the IsNil predicate on the "created_by" field. +func CreatedByIsNil() predicate.Entity { + return predicate.Entity(sql.FieldIsNull(FieldCreatedBy)) +} + +// CreatedByNotNil applies the NotNil predicate on the "created_by" field. +func CreatedByNotNil() predicate.Entity { + return predicate.Entity(sql.FieldNotNull(FieldCreatedBy)) +} + +// UploadSessionIDEQ applies the EQ predicate on the "upload_session_id" field. +func UploadSessionIDEQ(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldEQ(FieldUploadSessionID, v)) +} + +// UploadSessionIDNEQ applies the NEQ predicate on the "upload_session_id" field. +func UploadSessionIDNEQ(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldNEQ(FieldUploadSessionID, v)) +} + +// UploadSessionIDIn applies the In predicate on the "upload_session_id" field. +func UploadSessionIDIn(vs ...uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldIn(FieldUploadSessionID, vs...)) +} + +// UploadSessionIDNotIn applies the NotIn predicate on the "upload_session_id" field. +func UploadSessionIDNotIn(vs ...uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldNotIn(FieldUploadSessionID, vs...)) +} + +// UploadSessionIDGT applies the GT predicate on the "upload_session_id" field. +func UploadSessionIDGT(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldGT(FieldUploadSessionID, v)) +} + +// UploadSessionIDGTE applies the GTE predicate on the "upload_session_id" field. +func UploadSessionIDGTE(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldGTE(FieldUploadSessionID, v)) +} + +// UploadSessionIDLT applies the LT predicate on the "upload_session_id" field. +func UploadSessionIDLT(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldLT(FieldUploadSessionID, v)) +} + +// UploadSessionIDLTE applies the LTE predicate on the "upload_session_id" field. +func UploadSessionIDLTE(v uuid.UUID) predicate.Entity { + return predicate.Entity(sql.FieldLTE(FieldUploadSessionID, v)) +} + +// UploadSessionIDIsNil applies the IsNil predicate on the "upload_session_id" field. +func UploadSessionIDIsNil() predicate.Entity { + return predicate.Entity(sql.FieldIsNull(FieldUploadSessionID)) +} + +// UploadSessionIDNotNil applies the NotNil predicate on the "upload_session_id" field. +func UploadSessionIDNotNil() predicate.Entity { + return predicate.Entity(sql.FieldNotNull(FieldUploadSessionID)) +} + +// RecycleOptionsIsNil applies the IsNil predicate on the "recycle_options" field. +func RecycleOptionsIsNil() predicate.Entity { + return predicate.Entity(sql.FieldIsNull(FieldRecycleOptions)) +} + +// RecycleOptionsNotNil applies the NotNil predicate on the "recycle_options" field. +func RecycleOptionsNotNil() predicate.Entity { + return predicate.Entity(sql.FieldNotNull(FieldRecycleOptions)) +} + +// HasFile applies the HasEdge predicate on the "file" edge. 
+func HasFile() predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, FileTable, FilePrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFileWith applies the HasEdge predicate on the "file" edge with a given conditions (other predicates). +func HasFileWith(preds ...predicate.File) predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := newFileStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasStoragePolicy applies the HasEdge predicate on the "storage_policy" edge. +func HasStoragePolicy() predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePolicyTable, StoragePolicyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasStoragePolicyWith applies the HasEdge predicate on the "storage_policy" edge with a given conditions (other predicates). +func HasStoragePolicyWith(preds ...predicate.StoragePolicy) predicate.Entity { + return predicate.Entity(func(s *sql.Selector) { + step := newStoragePolicyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Entity) predicate.Entity { + return predicate.Entity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Entity) predicate.Entity { + return predicate.Entity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Entity) predicate.Entity { + return predicate.Entity(sql.NotPredicates(p)) +} diff --git a/ent/entity_create.go b/ent/entity_create.go new file mode 100644 index 00000000..48768bdd --- /dev/null +++ b/ent/entity_create.go @@ -0,0 +1,1267 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// EntityCreate is the builder for creating a Entity entity. +type EntityCreate struct { + config + mutation *EntityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. 
+func (ec *EntityCreate) SetCreatedAt(t time.Time) *EntityCreate { + ec.mutation.SetCreatedAt(t) + return ec +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (ec *EntityCreate) SetNillableCreatedAt(t *time.Time) *EntityCreate { + if t != nil { + ec.SetCreatedAt(*t) + } + return ec +} + +// SetUpdatedAt sets the "updated_at" field. +func (ec *EntityCreate) SetUpdatedAt(t time.Time) *EntityCreate { + ec.mutation.SetUpdatedAt(t) + return ec +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (ec *EntityCreate) SetNillableUpdatedAt(t *time.Time) *EntityCreate { + if t != nil { + ec.SetUpdatedAt(*t) + } + return ec +} + +// SetDeletedAt sets the "deleted_at" field. +func (ec *EntityCreate) SetDeletedAt(t time.Time) *EntityCreate { + ec.mutation.SetDeletedAt(t) + return ec +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (ec *EntityCreate) SetNillableDeletedAt(t *time.Time) *EntityCreate { + if t != nil { + ec.SetDeletedAt(*t) + } + return ec +} + +// SetType sets the "type" field. +func (ec *EntityCreate) SetType(i int) *EntityCreate { + ec.mutation.SetType(i) + return ec +} + +// SetSource sets the "source" field. +func (ec *EntityCreate) SetSource(s string) *EntityCreate { + ec.mutation.SetSource(s) + return ec +} + +// SetSize sets the "size" field. +func (ec *EntityCreate) SetSize(i int64) *EntityCreate { + ec.mutation.SetSize(i) + return ec +} + +// SetReferenceCount sets the "reference_count" field. +func (ec *EntityCreate) SetReferenceCount(i int) *EntityCreate { + ec.mutation.SetReferenceCount(i) + return ec +} + +// SetNillableReferenceCount sets the "reference_count" field if the given value is not nil. +func (ec *EntityCreate) SetNillableReferenceCount(i *int) *EntityCreate { + if i != nil { + ec.SetReferenceCount(*i) + } + return ec +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (ec *EntityCreate) SetStoragePolicyEntities(i int) *EntityCreate { + ec.mutation.SetStoragePolicyEntities(i) + return ec +} + +// SetCreatedBy sets the "created_by" field. +func (ec *EntityCreate) SetCreatedBy(i int) *EntityCreate { + ec.mutation.SetCreatedBy(i) + return ec +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (ec *EntityCreate) SetNillableCreatedBy(i *int) *EntityCreate { + if i != nil { + ec.SetCreatedBy(*i) + } + return ec +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (ec *EntityCreate) SetUploadSessionID(u uuid.UUID) *EntityCreate { + ec.mutation.SetUploadSessionID(u) + return ec +} + +// SetNillableUploadSessionID sets the "upload_session_id" field if the given value is not nil. +func (ec *EntityCreate) SetNillableUploadSessionID(u *uuid.UUID) *EntityCreate { + if u != nil { + ec.SetUploadSessionID(*u) + } + return ec +} + +// SetRecycleOptions sets the "recycle_options" field. +func (ec *EntityCreate) SetRecycleOptions(tro *types.EntityRecycleOption) *EntityCreate { + ec.mutation.SetRecycleOptions(tro) + return ec +} + +// AddFileIDs adds the "file" edge to the File entity by IDs. +func (ec *EntityCreate) AddFileIDs(ids ...int) *EntityCreate { + ec.mutation.AddFileIDs(ids...) + return ec +} + +// AddFile adds the "file" edges to the File entity. +func (ec *EntityCreate) AddFile(f ...*File) *EntityCreate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return ec.AddFileIDs(ids...) 
+} + +// SetUserID sets the "user" edge to the User entity by ID. +func (ec *EntityCreate) SetUserID(id int) *EntityCreate { + ec.mutation.SetUserID(id) + return ec +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (ec *EntityCreate) SetNillableUserID(id *int) *EntityCreate { + if id != nil { + ec = ec.SetUserID(*id) + } + return ec +} + +// SetUser sets the "user" edge to the User entity. +func (ec *EntityCreate) SetUser(u *User) *EntityCreate { + return ec.SetUserID(u.ID) +} + +// SetStoragePolicyID sets the "storage_policy" edge to the StoragePolicy entity by ID. +func (ec *EntityCreate) SetStoragePolicyID(id int) *EntityCreate { + ec.mutation.SetStoragePolicyID(id) + return ec +} + +// SetStoragePolicy sets the "storage_policy" edge to the StoragePolicy entity. +func (ec *EntityCreate) SetStoragePolicy(s *StoragePolicy) *EntityCreate { + return ec.SetStoragePolicyID(s.ID) +} + +// Mutation returns the EntityMutation object of the builder. +func (ec *EntityCreate) Mutation() *EntityMutation { + return ec.mutation +} + +// Save creates the Entity in the database. +func (ec *EntityCreate) Save(ctx context.Context) (*Entity, error) { + if err := ec.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, ec.sqlSave, ec.mutation, ec.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (ec *EntityCreate) SaveX(ctx context.Context) *Entity { + v, err := ec.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (ec *EntityCreate) Exec(ctx context.Context) error { + _, err := ec.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ec *EntityCreate) ExecX(ctx context.Context) { + if err := ec.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (ec *EntityCreate) defaults() error { + if _, ok := ec.mutation.CreatedAt(); !ok { + if entity.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized entity.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := entity.DefaultCreatedAt() + ec.mutation.SetCreatedAt(v) + } + if _, ok := ec.mutation.UpdatedAt(); !ok { + if entity.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized entity.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := entity.DefaultUpdatedAt() + ec.mutation.SetUpdatedAt(v) + } + if _, ok := ec.mutation.ReferenceCount(); !ok { + v := entity.DefaultReferenceCount + ec.mutation.SetReferenceCount(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (ec *EntityCreate) check() error { + if _, ok := ec.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Entity.created_at"`)} + } + if _, ok := ec.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Entity.updated_at"`)} + } + if _, ok := ec.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "Entity.type"`)} + } + if _, ok := ec.mutation.Source(); !ok { + return &ValidationError{Name: "source", err: errors.New(`ent: missing required field "Entity.source"`)} + } + if _, ok := ec.mutation.Size(); !ok { + return &ValidationError{Name: "size", err: errors.New(`ent: missing required field "Entity.size"`)} + } + if _, ok := ec.mutation.ReferenceCount(); !ok { + return &ValidationError{Name: "reference_count", err: errors.New(`ent: missing required field "Entity.reference_count"`)} + } + if _, ok := ec.mutation.StoragePolicyEntities(); !ok { + return &ValidationError{Name: "storage_policy_entities", err: errors.New(`ent: missing required field "Entity.storage_policy_entities"`)} + } + if _, ok := ec.mutation.StoragePolicyID(); !ok { + return &ValidationError{Name: "storage_policy", err: errors.New(`ent: missing required edge "Entity.storage_policy"`)} + } + return nil +} + +func (ec *EntityCreate) sqlSave(ctx context.Context) (*Entity, error) { + if err := ec.check(); err != nil { + return nil, err + } + _node, _spec := ec.createSpec() + if err := sqlgraph.CreateNode(ctx, ec.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + ec.mutation.id = &_node.ID + ec.mutation.done = true + return _node, nil +} + +func (ec *EntityCreate) createSpec() (*Entity, *sqlgraph.CreateSpec) { + var ( + _node = &Entity{config: ec.config} + _spec = sqlgraph.NewCreateSpec(entity.Table, sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt)) + ) + + if id, ok := ec.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = ec.conflict + if value, ok := ec.mutation.CreatedAt(); ok { + _spec.SetField(entity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := ec.mutation.UpdatedAt(); ok { + _spec.SetField(entity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := ec.mutation.DeletedAt(); ok { + _spec.SetField(entity.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := ec.mutation.GetType(); ok { + _spec.SetField(entity.FieldType, field.TypeInt, value) + _node.Type = value + } + if value, ok := ec.mutation.Source(); ok { + _spec.SetField(entity.FieldSource, field.TypeString, value) + _node.Source = value + } + if value, ok := ec.mutation.Size(); ok { + _spec.SetField(entity.FieldSize, field.TypeInt64, value) + _node.Size = value + } + if value, ok := ec.mutation.ReferenceCount(); ok { + _spec.SetField(entity.FieldReferenceCount, field.TypeInt, value) + _node.ReferenceCount = value + } + if value, ok := ec.mutation.UploadSessionID(); ok { + _spec.SetField(entity.FieldUploadSessionID, field.TypeUUID, value) + _node.UploadSessionID = &value + } + if value, ok := ec.mutation.RecycleOptions(); ok { + _spec.SetField(entity.FieldRecycleOptions, field.TypeJSON, value) + _node.RecycleOptions = value + } + if nodes := ec.mutation.FileIDs(); 
len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := ec.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.UserTable, + Columns: []string{entity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.CreatedBy = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := ec.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.StoragePolicyTable, + Columns: []string{entity.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.StoragePolicyEntities = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Entity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.EntityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (ec *EntityCreate) OnConflict(opts ...sql.ConflictOption) *EntityUpsertOne { + ec.conflict = opts + return &EntityUpsertOne{ + create: ec, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (ec *EntityCreate) OnConflictColumns(columns ...string) *EntityUpsertOne { + ec.conflict = append(ec.conflict, sql.ConflictColumns(columns...)) + return &EntityUpsertOne{ + create: ec, + } +} + +type ( + // EntityUpsertOne is the builder for "upsert"-ing + // one Entity node. + EntityUpsertOne struct { + create *EntityCreate + } + + // EntityUpsert is the "OnConflict" setter. + EntityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *EntityUpsert) SetUpdatedAt(v time.Time) *EntityUpsert { + u.Set(entity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *EntityUpsert) UpdateUpdatedAt() *EntityUpsert { + u.SetExcluded(entity.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *EntityUpsert) SetDeletedAt(v time.Time) *EntityUpsert { + u.Set(entity.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *EntityUpsert) UpdateDeletedAt() *EntityUpsert { + u.SetExcluded(entity.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. 
+func (u *EntityUpsert) ClearDeletedAt() *EntityUpsert { + u.SetNull(entity.FieldDeletedAt) + return u +} + +// SetType sets the "type" field. +func (u *EntityUpsert) SetType(v int) *EntityUpsert { + u.Set(entity.FieldType, v) + return u +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *EntityUpsert) UpdateType() *EntityUpsert { + u.SetExcluded(entity.FieldType) + return u +} + +// AddType adds v to the "type" field. +func (u *EntityUpsert) AddType(v int) *EntityUpsert { + u.Add(entity.FieldType, v) + return u +} + +// SetSource sets the "source" field. +func (u *EntityUpsert) SetSource(v string) *EntityUpsert { + u.Set(entity.FieldSource, v) + return u +} + +// UpdateSource sets the "source" field to the value that was provided on create. +func (u *EntityUpsert) UpdateSource() *EntityUpsert { + u.SetExcluded(entity.FieldSource) + return u +} + +// SetSize sets the "size" field. +func (u *EntityUpsert) SetSize(v int64) *EntityUpsert { + u.Set(entity.FieldSize, v) + return u +} + +// UpdateSize sets the "size" field to the value that was provided on create. +func (u *EntityUpsert) UpdateSize() *EntityUpsert { + u.SetExcluded(entity.FieldSize) + return u +} + +// AddSize adds v to the "size" field. +func (u *EntityUpsert) AddSize(v int64) *EntityUpsert { + u.Add(entity.FieldSize, v) + return u +} + +// SetReferenceCount sets the "reference_count" field. +func (u *EntityUpsert) SetReferenceCount(v int) *EntityUpsert { + u.Set(entity.FieldReferenceCount, v) + return u +} + +// UpdateReferenceCount sets the "reference_count" field to the value that was provided on create. +func (u *EntityUpsert) UpdateReferenceCount() *EntityUpsert { + u.SetExcluded(entity.FieldReferenceCount) + return u +} + +// AddReferenceCount adds v to the "reference_count" field. +func (u *EntityUpsert) AddReferenceCount(v int) *EntityUpsert { + u.Add(entity.FieldReferenceCount, v) + return u +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (u *EntityUpsert) SetStoragePolicyEntities(v int) *EntityUpsert { + u.Set(entity.FieldStoragePolicyEntities, v) + return u +} + +// UpdateStoragePolicyEntities sets the "storage_policy_entities" field to the value that was provided on create. +func (u *EntityUpsert) UpdateStoragePolicyEntities() *EntityUpsert { + u.SetExcluded(entity.FieldStoragePolicyEntities) + return u +} + +// SetCreatedBy sets the "created_by" field. +func (u *EntityUpsert) SetCreatedBy(v int) *EntityUpsert { + u.Set(entity.FieldCreatedBy, v) + return u +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *EntityUpsert) UpdateCreatedBy() *EntityUpsert { + u.SetExcluded(entity.FieldCreatedBy) + return u +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (u *EntityUpsert) ClearCreatedBy() *EntityUpsert { + u.SetNull(entity.FieldCreatedBy) + return u +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (u *EntityUpsert) SetUploadSessionID(v uuid.UUID) *EntityUpsert { + u.Set(entity.FieldUploadSessionID, v) + return u +} + +// UpdateUploadSessionID sets the "upload_session_id" field to the value that was provided on create. +func (u *EntityUpsert) UpdateUploadSessionID() *EntityUpsert { + u.SetExcluded(entity.FieldUploadSessionID) + return u +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. 
+func (u *EntityUpsert) ClearUploadSessionID() *EntityUpsert { + u.SetNull(entity.FieldUploadSessionID) + return u +} + +// SetRecycleOptions sets the "recycle_options" field. +func (u *EntityUpsert) SetRecycleOptions(v *types.EntityRecycleOption) *EntityUpsert { + u.Set(entity.FieldRecycleOptions, v) + return u +} + +// UpdateRecycleOptions sets the "recycle_options" field to the value that was provided on create. +func (u *EntityUpsert) UpdateRecycleOptions() *EntityUpsert { + u.SetExcluded(entity.FieldRecycleOptions) + return u +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. +func (u *EntityUpsert) ClearRecycleOptions() *EntityUpsert { + u.SetNull(entity.FieldRecycleOptions) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *EntityUpsertOne) UpdateNewValues() *EntityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(entity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *EntityUpsertOne) Ignore() *EntityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *EntityUpsertOne) DoNothing() *EntityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the EntityCreate.OnConflict +// documentation for more info. +func (u *EntityUpsertOne) Update(set func(*EntityUpsert)) *EntityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&EntityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *EntityUpsertOne) SetUpdatedAt(v time.Time) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateUpdatedAt() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *EntityUpsertOne) SetDeletedAt(v time.Time) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateDeletedAt() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *EntityUpsertOne) ClearDeletedAt() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *EntityUpsertOne) SetType(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetType(v) + }) +} + +// AddType adds v to the "type" field. 
+func (u *EntityUpsertOne) AddType(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.AddType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateType() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateType() + }) +} + +// SetSource sets the "source" field. +func (u *EntityUpsertOne) SetSource(v string) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetSource(v) + }) +} + +// UpdateSource sets the "source" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateSource() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateSource() + }) +} + +// SetSize sets the "size" field. +func (u *EntityUpsertOne) SetSize(v int64) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetSize(v) + }) +} + +// AddSize adds v to the "size" field. +func (u *EntityUpsertOne) AddSize(v int64) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.AddSize(v) + }) +} + +// UpdateSize sets the "size" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateSize() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateSize() + }) +} + +// SetReferenceCount sets the "reference_count" field. +func (u *EntityUpsertOne) SetReferenceCount(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetReferenceCount(v) + }) +} + +// AddReferenceCount adds v to the "reference_count" field. +func (u *EntityUpsertOne) AddReferenceCount(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.AddReferenceCount(v) + }) +} + +// UpdateReferenceCount sets the "reference_count" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateReferenceCount() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateReferenceCount() + }) +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (u *EntityUpsertOne) SetStoragePolicyEntities(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetStoragePolicyEntities(v) + }) +} + +// UpdateStoragePolicyEntities sets the "storage_policy_entities" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateStoragePolicyEntities() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateStoragePolicyEntities() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *EntityUpsertOne) SetCreatedBy(v int) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateCreatedBy() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateCreatedBy() + }) +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (u *EntityUpsertOne) ClearCreatedBy() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.ClearCreatedBy() + }) +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (u *EntityUpsertOne) SetUploadSessionID(v uuid.UUID) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetUploadSessionID(v) + }) +} + +// UpdateUploadSessionID sets the "upload_session_id" field to the value that was provided on create. 
+func (u *EntityUpsertOne) UpdateUploadSessionID() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateUploadSessionID() + }) +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. +func (u *EntityUpsertOne) ClearUploadSessionID() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.ClearUploadSessionID() + }) +} + +// SetRecycleOptions sets the "recycle_options" field. +func (u *EntityUpsertOne) SetRecycleOptions(v *types.EntityRecycleOption) *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.SetRecycleOptions(v) + }) +} + +// UpdateRecycleOptions sets the "recycle_options" field to the value that was provided on create. +func (u *EntityUpsertOne) UpdateRecycleOptions() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.UpdateRecycleOptions() + }) +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. +func (u *EntityUpsertOne) ClearRecycleOptions() *EntityUpsertOne { + return u.Update(func(s *EntityUpsert) { + s.ClearRecycleOptions() + }) +} + +// Exec executes the query. +func (u *EntityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for EntityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *EntityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *EntityUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *EntityUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *EntityCreate) SetRawID(t int) *EntityCreate { + m.mutation.SetRawID(t) + return m +} + +// EntityCreateBulk is the builder for creating many Entity entities in bulk. +type EntityCreateBulk struct { + config + err error + builders []*EntityCreate + conflict []sql.ConflictOption +} + +// Save creates the Entity entities in the database. +func (ecb *EntityCreateBulk) Save(ctx context.Context) ([]*Entity, error) { + if ecb.err != nil { + return nil, ecb.err + } + specs := make([]*sqlgraph.CreateSpec, len(ecb.builders)) + nodes := make([]*Entity, len(ecb.builders)) + mutators := make([]Mutator, len(ecb.builders)) + for i := range ecb.builders { + func(i int, root context.Context) { + builder := ecb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*EntityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, ecb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = ecb.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, ecb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, ecb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (ecb *EntityCreateBulk) SaveX(ctx context.Context) []*Entity { + v, err := ecb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (ecb *EntityCreateBulk) Exec(ctx context.Context) error { + _, err := ecb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ecb *EntityCreateBulk) ExecX(ctx context.Context) { + if err := ecb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Entity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.EntityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (ecb *EntityCreateBulk) OnConflict(opts ...sql.ConflictOption) *EntityUpsertBulk { + ecb.conflict = opts + return &EntityUpsertBulk{ + create: ecb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (ecb *EntityCreateBulk) OnConflictColumns(columns ...string) *EntityUpsertBulk { + ecb.conflict = append(ecb.conflict, sql.ConflictColumns(columns...)) + return &EntityUpsertBulk{ + create: ecb, + } +} + +// EntityUpsertBulk is the builder for "upsert"-ing +// a bulk of Entity nodes. +type EntityUpsertBulk struct { + create *EntityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *EntityUpsertBulk) UpdateNewValues() *EntityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(entity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Entity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *EntityUpsertBulk) Ignore() *EntityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *EntityUpsertBulk) DoNothing() *EntityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the EntityCreateBulk.OnConflict +// documentation for more info. +func (u *EntityUpsertBulk) Update(set func(*EntityUpsert)) *EntityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&EntityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *EntityUpsertBulk) SetUpdatedAt(v time.Time) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateUpdatedAt() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *EntityUpsertBulk) SetDeletedAt(v time.Time) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateDeletedAt() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *EntityUpsertBulk) ClearDeletedAt() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *EntityUpsertBulk) SetType(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetType(v) + }) +} + +// AddType adds v to the "type" field. +func (u *EntityUpsertBulk) AddType(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.AddType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateType() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateType() + }) +} + +// SetSource sets the "source" field. +func (u *EntityUpsertBulk) SetSource(v string) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetSource(v) + }) +} + +// UpdateSource sets the "source" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateSource() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateSource() + }) +} + +// SetSize sets the "size" field. +func (u *EntityUpsertBulk) SetSize(v int64) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetSize(v) + }) +} + +// AddSize adds v to the "size" field. +func (u *EntityUpsertBulk) AddSize(v int64) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.AddSize(v) + }) +} + +// UpdateSize sets the "size" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateSize() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateSize() + }) +} + +// SetReferenceCount sets the "reference_count" field. +func (u *EntityUpsertBulk) SetReferenceCount(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetReferenceCount(v) + }) +} + +// AddReferenceCount adds v to the "reference_count" field. +func (u *EntityUpsertBulk) AddReferenceCount(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.AddReferenceCount(v) + }) +} + +// UpdateReferenceCount sets the "reference_count" field to the value that was provided on create. 
+func (u *EntityUpsertBulk) UpdateReferenceCount() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateReferenceCount() + }) +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (u *EntityUpsertBulk) SetStoragePolicyEntities(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetStoragePolicyEntities(v) + }) +} + +// UpdateStoragePolicyEntities sets the "storage_policy_entities" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateStoragePolicyEntities() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateStoragePolicyEntities() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *EntityUpsertBulk) SetCreatedBy(v int) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateCreatedBy() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateCreatedBy() + }) +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (u *EntityUpsertBulk) ClearCreatedBy() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.ClearCreatedBy() + }) +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (u *EntityUpsertBulk) SetUploadSessionID(v uuid.UUID) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetUploadSessionID(v) + }) +} + +// UpdateUploadSessionID sets the "upload_session_id" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateUploadSessionID() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateUploadSessionID() + }) +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. +func (u *EntityUpsertBulk) ClearUploadSessionID() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.ClearUploadSessionID() + }) +} + +// SetRecycleOptions sets the "recycle_options" field. +func (u *EntityUpsertBulk) SetRecycleOptions(v *types.EntityRecycleOption) *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.SetRecycleOptions(v) + }) +} + +// UpdateRecycleOptions sets the "recycle_options" field to the value that was provided on create. +func (u *EntityUpsertBulk) UpdateRecycleOptions() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.UpdateRecycleOptions() + }) +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. +func (u *EntityUpsertBulk) ClearRecycleOptions() *EntityUpsertBulk { + return u.Update(func(s *EntityUpsert) { + s.ClearRecycleOptions() + }) +} + +// Exec executes the query. +func (u *EntityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the EntityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for EntityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *EntityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/entity_delete.go b/ent/entity_delete.go new file mode 100644 index 00000000..ee131ceb --- /dev/null +++ b/ent/entity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// EntityDelete is the builder for deleting a Entity entity. +type EntityDelete struct { + config + hooks []Hook + mutation *EntityMutation +} + +// Where appends a list predicates to the EntityDelete builder. +func (ed *EntityDelete) Where(ps ...predicate.Entity) *EntityDelete { + ed.mutation.Where(ps...) + return ed +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ed *EntityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ed.sqlExec, ed.mutation, ed.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ed *EntityDelete) ExecX(ctx context.Context) int { + n, err := ed.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ed *EntityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(entity.Table, sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt)) + if ps := ed.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ed.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ed.mutation.done = true + return affected, err +} + +// EntityDeleteOne is the builder for deleting a single Entity entity. +type EntityDeleteOne struct { + ed *EntityDelete +} + +// Where appends a list predicates to the EntityDelete builder. +func (edo *EntityDeleteOne) Where(ps ...predicate.Entity) *EntityDeleteOne { + edo.ed.mutation.Where(ps...) + return edo +} + +// Exec executes the deletion query. +func (edo *EntityDeleteOne) Exec(ctx context.Context) error { + n, err := edo.ed.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{entity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (edo *EntityDeleteOne) ExecX(ctx context.Context) { + if err := edo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/entity_query.go b/ent/entity_query.go new file mode 100644 index 00000000..259d643e --- /dev/null +++ b/ent/entity_query.go @@ -0,0 +1,786 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// EntityQuery is the builder for querying Entity entities. +type EntityQuery struct { + config + ctx *QueryContext + order []entity.OrderOption + inters []Interceptor + predicates []predicate.Entity + withFile *FileQuery + withUser *UserQuery + withStoragePolicy *StoragePolicyQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the EntityQuery builder. +func (eq *EntityQuery) Where(ps ...predicate.Entity) *EntityQuery { + eq.predicates = append(eq.predicates, ps...) 
+ return eq +} + +// Limit the number of records to be returned by this query. +func (eq *EntityQuery) Limit(limit int) *EntityQuery { + eq.ctx.Limit = &limit + return eq +} + +// Offset to start from. +func (eq *EntityQuery) Offset(offset int) *EntityQuery { + eq.ctx.Offset = &offset + return eq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (eq *EntityQuery) Unique(unique bool) *EntityQuery { + eq.ctx.Unique = &unique + return eq +} + +// Order specifies how the records should be ordered. +func (eq *EntityQuery) Order(o ...entity.OrderOption) *EntityQuery { + eq.order = append(eq.order, o...) + return eq +} + +// QueryFile chains the current query on the "file" edge. +func (eq *EntityQuery) QueryFile() *FileQuery { + query := (&FileClient{config: eq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := eq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := eq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, entity.FileTable, entity.FilePrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(eq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryUser chains the current query on the "user" edge. +func (eq *EntityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: eq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := eq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := eq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, entity.UserTable, entity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(eq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryStoragePolicy chains the current query on the "storage_policy" edge. +func (eq *EntityQuery) QueryStoragePolicy() *StoragePolicyQuery { + query := (&StoragePolicyClient{config: eq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := eq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := eq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(entity.Table, entity.FieldID, selector), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, entity.StoragePolicyTable, entity.StoragePolicyColumn), + ) + fromU = sqlgraph.SetNeighbors(eq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Entity entity from the query. +// Returns a *NotFoundError when no Entity was found. +func (eq *EntityQuery) First(ctx context.Context) (*Entity, error) { + nodes, err := eq.Limit(1).All(setContextOp(ctx, eq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{entity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. 
+func (eq *EntityQuery) FirstX(ctx context.Context) *Entity { + node, err := eq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Entity ID from the query. +// Returns a *NotFoundError when no Entity ID was found. +func (eq *EntityQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = eq.Limit(1).IDs(setContextOp(ctx, eq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{entity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (eq *EntityQuery) FirstIDX(ctx context.Context) int { + id, err := eq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Entity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Entity entity is found. +// Returns a *NotFoundError when no Entity entities are found. +func (eq *EntityQuery) Only(ctx context.Context) (*Entity, error) { + nodes, err := eq.Limit(2).All(setContextOp(ctx, eq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{entity.Label} + default: + return nil, &NotSingularError{entity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (eq *EntityQuery) OnlyX(ctx context.Context) *Entity { + node, err := eq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Entity ID in the query. +// Returns a *NotSingularError when more than one Entity ID is found. +// Returns a *NotFoundError when no entities are found. +func (eq *EntityQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = eq.Limit(2).IDs(setContextOp(ctx, eq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{entity.Label} + default: + err = &NotSingularError{entity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (eq *EntityQuery) OnlyIDX(ctx context.Context) int { + id, err := eq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Entities. +func (eq *EntityQuery) All(ctx context.Context) ([]*Entity, error) { + ctx = setContextOp(ctx, eq.ctx, "All") + if err := eq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Entity, *EntityQuery]() + return withInterceptors[[]*Entity](ctx, eq, qr, eq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (eq *EntityQuery) AllX(ctx context.Context) []*Entity { + nodes, err := eq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Entity IDs. +func (eq *EntityQuery) IDs(ctx context.Context) (ids []int, err error) { + if eq.ctx.Unique == nil && eq.path != nil { + eq.Unique(true) + } + ctx = setContextOp(ctx, eq.ctx, "IDs") + if err = eq.Select(entity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (eq *EntityQuery) IDsX(ctx context.Context) []int { + ids, err := eq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. 
+func (eq *EntityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, eq.ctx, "Count") + if err := eq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, eq, querierCount[*EntityQuery](), eq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (eq *EntityQuery) CountX(ctx context.Context) int { + count, err := eq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (eq *EntityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, eq.ctx, "Exist") + switch _, err := eq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (eq *EntityQuery) ExistX(ctx context.Context) bool { + exist, err := eq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the EntityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (eq *EntityQuery) Clone() *EntityQuery { + if eq == nil { + return nil + } + return &EntityQuery{ + config: eq.config, + ctx: eq.ctx.Clone(), + order: append([]entity.OrderOption{}, eq.order...), + inters: append([]Interceptor{}, eq.inters...), + predicates: append([]predicate.Entity{}, eq.predicates...), + withFile: eq.withFile.Clone(), + withUser: eq.withUser.Clone(), + withStoragePolicy: eq.withStoragePolicy.Clone(), + // clone intermediate query. + sql: eq.sql.Clone(), + path: eq.path, + } +} + +// WithFile tells the query-builder to eager-load the nodes that are connected to +// the "file" edge. The optional arguments are used to configure the query builder of the edge. +func (eq *EntityQuery) WithFile(opts ...func(*FileQuery)) *EntityQuery { + query := (&FileClient{config: eq.config}).Query() + for _, opt := range opts { + opt(query) + } + eq.withFile = query + return eq +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (eq *EntityQuery) WithUser(opts ...func(*UserQuery)) *EntityQuery { + query := (&UserClient{config: eq.config}).Query() + for _, opt := range opts { + opt(query) + } + eq.withUser = query + return eq +} + +// WithStoragePolicy tells the query-builder to eager-load the nodes that are connected to +// the "storage_policy" edge. The optional arguments are used to configure the query builder of the edge. +func (eq *EntityQuery) WithStoragePolicy(opts ...func(*StoragePolicyQuery)) *EntityQuery { + query := (&StoragePolicyClient{config: eq.config}).Query() + for _, opt := range opts { + opt(query) + } + eq.withStoragePolicy = query + return eq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Entity.Query(). +// GroupBy(entity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (eq *EntityQuery) GroupBy(field string, fields ...string) *EntityGroupBy { + eq.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &EntityGroupBy{build: eq} + grbuild.flds = &eq.ctx.Fields + grbuild.label = entity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Entity.Query(). +// Select(entity.FieldCreatedAt). +// Scan(ctx, &v) +func (eq *EntityQuery) Select(fields ...string) *EntitySelect { + eq.ctx.Fields = append(eq.ctx.Fields, fields...) + sbuild := &EntitySelect{EntityQuery: eq} + sbuild.label = entity.Label + sbuild.flds, sbuild.scan = &eq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a EntitySelect configured with the given aggregations. +func (eq *EntityQuery) Aggregate(fns ...AggregateFunc) *EntitySelect { + return eq.Select().Aggregate(fns...) +} + +func (eq *EntityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range eq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, eq); err != nil { + return err + } + } + } + for _, f := range eq.ctx.Fields { + if !entity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if eq.path != nil { + prev, err := eq.path(ctx) + if err != nil { + return err + } + eq.sql = prev + } + return nil +} + +func (eq *EntityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Entity, error) { + var ( + nodes = []*Entity{} + _spec = eq.querySpec() + loadedTypes = [3]bool{ + eq.withFile != nil, + eq.withUser != nil, + eq.withStoragePolicy != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Entity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Entity{config: eq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, eq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := eq.withFile; query != nil { + if err := eq.loadFile(ctx, query, nodes, + func(n *Entity) { n.Edges.File = []*File{} }, + func(n *Entity, e *File) { n.Edges.File = append(n.Edges.File, e) }); err != nil { + return nil, err + } + } + if query := eq.withUser; query != nil { + if err := eq.loadUser(ctx, query, nodes, nil, + func(n *Entity, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := eq.withStoragePolicy; query != nil { + if err := eq.loadStoragePolicy(ctx, query, nodes, nil, + func(n *Entity, e *StoragePolicy) { n.Edges.StoragePolicy = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (eq *EntityQuery) loadFile(ctx context.Context, query *FileQuery, nodes []*Entity, init func(*Entity), assign func(*Entity, *File)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Entity) + nids := make(map[int]map[*Entity]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(entity.FileTable) + s.Join(joinT).On(s.C(file.FieldID), joinT.C(entity.FilePrimaryKey[0])) + 
s.Where(sql.InValues(joinT.C(entity.FilePrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(entity.FilePrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Entity]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*File](ctx, query, qr, query.inters) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "file" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (eq *EntityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Entity, init func(*Entity), assign func(*Entity, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Entity) + for i := range nodes { + fk := nodes[i].CreatedBy + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "created_by" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (eq *EntityQuery) loadStoragePolicy(ctx context.Context, query *StoragePolicyQuery, nodes []*Entity, init func(*Entity), assign func(*Entity, *StoragePolicy)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Entity) + for i := range nodes { + fk := nodes[i].StoragePolicyEntities + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(storagepolicy.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "storage_policy_entities" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (eq *EntityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := eq.querySpec() + _spec.Node.Columns = eq.ctx.Fields + if len(eq.ctx.Fields) > 0 { + _spec.Unique = eq.ctx.Unique != nil && *eq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, eq.driver, _spec) +} + +func (eq *EntityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(entity.Table, entity.Columns, sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt)) + _spec.From = eq.sql + if unique := eq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if eq.path != nil { + _spec.Unique = true + } + if fields := eq.ctx.Fields; 
len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, entity.FieldID) + for i := range fields { + if fields[i] != entity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if eq.withUser != nil { + _spec.Node.AddColumnOnce(entity.FieldCreatedBy) + } + if eq.withStoragePolicy != nil { + _spec.Node.AddColumnOnce(entity.FieldStoragePolicyEntities) + } + } + if ps := eq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := eq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := eq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := eq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (eq *EntityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(eq.driver.Dialect()) + t1 := builder.Table(entity.Table) + columns := eq.ctx.Fields + if len(columns) == 0 { + columns = entity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if eq.sql != nil { + selector = eq.sql + selector.Select(selector.Columns(columns...)...) + } + if eq.ctx.Unique != nil && *eq.ctx.Unique { + selector.Distinct() + } + for _, p := range eq.predicates { + p(selector) + } + for _, p := range eq.order { + p(selector) + } + if offset := eq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := eq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// EntityGroupBy is the group-by builder for Entity entities. +type EntityGroupBy struct { + selector + build *EntityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (egb *EntityGroupBy) Aggregate(fns ...AggregateFunc) *EntityGroupBy { + egb.fns = append(egb.fns, fns...) + return egb +} + +// Scan applies the selector query and scans the result into the given value. +func (egb *EntityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, egb.build.ctx, "GroupBy") + if err := egb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*EntityQuery, *EntityGroupBy](ctx, egb.build, egb, egb.build.inters, v) +} + +func (egb *EntityGroupBy) sqlScan(ctx context.Context, root *EntityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(egb.fns)) + for _, fn := range egb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*egb.flds)+len(egb.fns)) + for _, f := range *egb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*egb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := egb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// EntitySelect is the builder for selecting fields of Entity entities. +type EntitySelect struct { + *EntityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. 
+func (es *EntitySelect) Aggregate(fns ...AggregateFunc) *EntitySelect { + es.fns = append(es.fns, fns...) + return es +} + +// Scan applies the selector query and scans the result into the given value. +func (es *EntitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, es.ctx, "Select") + if err := es.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*EntityQuery, *EntitySelect](ctx, es.EntityQuery, es, es.inters, v) +} + +func (es *EntitySelect) sqlScan(ctx context.Context, root *EntityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(es.fns)) + for _, fn := range es.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*es.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := es.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/entity_update.go b/ent/entity_update.go new file mode 100644 index 00000000..dfd9c66a --- /dev/null +++ b/ent/entity_update.go @@ -0,0 +1,1017 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// EntityUpdate is the builder for updating Entity entities. +type EntityUpdate struct { + config + hooks []Hook + mutation *EntityMutation +} + +// Where appends a list predicates to the EntityUpdate builder. +func (eu *EntityUpdate) Where(ps ...predicate.Entity) *EntityUpdate { + eu.mutation.Where(ps...) + return eu +} + +// SetUpdatedAt sets the "updated_at" field. +func (eu *EntityUpdate) SetUpdatedAt(t time.Time) *EntityUpdate { + eu.mutation.SetUpdatedAt(t) + return eu +} + +// SetDeletedAt sets the "deleted_at" field. +func (eu *EntityUpdate) SetDeletedAt(t time.Time) *EntityUpdate { + eu.mutation.SetDeletedAt(t) + return eu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableDeletedAt(t *time.Time) *EntityUpdate { + if t != nil { + eu.SetDeletedAt(*t) + } + return eu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (eu *EntityUpdate) ClearDeletedAt() *EntityUpdate { + eu.mutation.ClearDeletedAt() + return eu +} + +// SetType sets the "type" field. +func (eu *EntityUpdate) SetType(i int) *EntityUpdate { + eu.mutation.ResetType() + eu.mutation.SetType(i) + return eu +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableType(i *int) *EntityUpdate { + if i != nil { + eu.SetType(*i) + } + return eu +} + +// AddType adds i to the "type" field. +func (eu *EntityUpdate) AddType(i int) *EntityUpdate { + eu.mutation.AddType(i) + return eu +} + +// SetSource sets the "source" field. 
+func (eu *EntityUpdate) SetSource(s string) *EntityUpdate { + eu.mutation.SetSource(s) + return eu +} + +// SetNillableSource sets the "source" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableSource(s *string) *EntityUpdate { + if s != nil { + eu.SetSource(*s) + } + return eu +} + +// SetSize sets the "size" field. +func (eu *EntityUpdate) SetSize(i int64) *EntityUpdate { + eu.mutation.ResetSize() + eu.mutation.SetSize(i) + return eu +} + +// SetNillableSize sets the "size" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableSize(i *int64) *EntityUpdate { + if i != nil { + eu.SetSize(*i) + } + return eu +} + +// AddSize adds i to the "size" field. +func (eu *EntityUpdate) AddSize(i int64) *EntityUpdate { + eu.mutation.AddSize(i) + return eu +} + +// SetReferenceCount sets the "reference_count" field. +func (eu *EntityUpdate) SetReferenceCount(i int) *EntityUpdate { + eu.mutation.ResetReferenceCount() + eu.mutation.SetReferenceCount(i) + return eu +} + +// SetNillableReferenceCount sets the "reference_count" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableReferenceCount(i *int) *EntityUpdate { + if i != nil { + eu.SetReferenceCount(*i) + } + return eu +} + +// AddReferenceCount adds i to the "reference_count" field. +func (eu *EntityUpdate) AddReferenceCount(i int) *EntityUpdate { + eu.mutation.AddReferenceCount(i) + return eu +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (eu *EntityUpdate) SetStoragePolicyEntities(i int) *EntityUpdate { + eu.mutation.SetStoragePolicyEntities(i) + return eu +} + +// SetNillableStoragePolicyEntities sets the "storage_policy_entities" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableStoragePolicyEntities(i *int) *EntityUpdate { + if i != nil { + eu.SetStoragePolicyEntities(*i) + } + return eu +} + +// SetCreatedBy sets the "created_by" field. +func (eu *EntityUpdate) SetCreatedBy(i int) *EntityUpdate { + eu.mutation.SetCreatedBy(i) + return eu +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableCreatedBy(i *int) *EntityUpdate { + if i != nil { + eu.SetCreatedBy(*i) + } + return eu +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (eu *EntityUpdate) ClearCreatedBy() *EntityUpdate { + eu.mutation.ClearCreatedBy() + return eu +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (eu *EntityUpdate) SetUploadSessionID(u uuid.UUID) *EntityUpdate { + eu.mutation.SetUploadSessionID(u) + return eu +} + +// SetNillableUploadSessionID sets the "upload_session_id" field if the given value is not nil. +func (eu *EntityUpdate) SetNillableUploadSessionID(u *uuid.UUID) *EntityUpdate { + if u != nil { + eu.SetUploadSessionID(*u) + } + return eu +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. +func (eu *EntityUpdate) ClearUploadSessionID() *EntityUpdate { + eu.mutation.ClearUploadSessionID() + return eu +} + +// SetRecycleOptions sets the "recycle_options" field. +func (eu *EntityUpdate) SetRecycleOptions(tro *types.EntityRecycleOption) *EntityUpdate { + eu.mutation.SetRecycleOptions(tro) + return eu +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. +func (eu *EntityUpdate) ClearRecycleOptions() *EntityUpdate { + eu.mutation.ClearRecycleOptions() + return eu +} + +// AddFileIDs adds the "file" edge to the File entity by IDs. 
+func (eu *EntityUpdate) AddFileIDs(ids ...int) *EntityUpdate { + eu.mutation.AddFileIDs(ids...) + return eu +} + +// AddFile adds the "file" edges to the File entity. +func (eu *EntityUpdate) AddFile(f ...*File) *EntityUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return eu.AddFileIDs(ids...) +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (eu *EntityUpdate) SetUserID(id int) *EntityUpdate { + eu.mutation.SetUserID(id) + return eu +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (eu *EntityUpdate) SetNillableUserID(id *int) *EntityUpdate { + if id != nil { + eu = eu.SetUserID(*id) + } + return eu +} + +// SetUser sets the "user" edge to the User entity. +func (eu *EntityUpdate) SetUser(u *User) *EntityUpdate { + return eu.SetUserID(u.ID) +} + +// SetStoragePolicyID sets the "storage_policy" edge to the StoragePolicy entity by ID. +func (eu *EntityUpdate) SetStoragePolicyID(id int) *EntityUpdate { + eu.mutation.SetStoragePolicyID(id) + return eu +} + +// SetStoragePolicy sets the "storage_policy" edge to the StoragePolicy entity. +func (eu *EntityUpdate) SetStoragePolicy(s *StoragePolicy) *EntityUpdate { + return eu.SetStoragePolicyID(s.ID) +} + +// Mutation returns the EntityMutation object of the builder. +func (eu *EntityUpdate) Mutation() *EntityMutation { + return eu.mutation +} + +// ClearFile clears all "file" edges to the File entity. +func (eu *EntityUpdate) ClearFile() *EntityUpdate { + eu.mutation.ClearFile() + return eu +} + +// RemoveFileIDs removes the "file" edge to File entities by IDs. +func (eu *EntityUpdate) RemoveFileIDs(ids ...int) *EntityUpdate { + eu.mutation.RemoveFileIDs(ids...) + return eu +} + +// RemoveFile removes "file" edges to File entities. +func (eu *EntityUpdate) RemoveFile(f ...*File) *EntityUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return eu.RemoveFileIDs(ids...) +} + +// ClearUser clears the "user" edge to the User entity. +func (eu *EntityUpdate) ClearUser() *EntityUpdate { + eu.mutation.ClearUser() + return eu +} + +// ClearStoragePolicy clears the "storage_policy" edge to the StoragePolicy entity. +func (eu *EntityUpdate) ClearStoragePolicy() *EntityUpdate { + eu.mutation.ClearStoragePolicy() + return eu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (eu *EntityUpdate) Save(ctx context.Context) (int, error) { + if err := eu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, eu.sqlSave, eu.mutation, eu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (eu *EntityUpdate) SaveX(ctx context.Context) int { + affected, err := eu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (eu *EntityUpdate) Exec(ctx context.Context) error { + _, err := eu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (eu *EntityUpdate) ExecX(ctx context.Context) { + if err := eu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (eu *EntityUpdate) defaults() error { + if _, ok := eu.mutation.UpdatedAt(); !ok { + if entity.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized entity.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := entity.UpdateDefaultUpdatedAt() + eu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (eu *EntityUpdate) check() error { + if _, ok := eu.mutation.StoragePolicyID(); eu.mutation.StoragePolicyCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Entity.storage_policy"`) + } + return nil +} + +func (eu *EntityUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := eu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(entity.Table, entity.Columns, sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt)) + if ps := eu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := eu.mutation.UpdatedAt(); ok { + _spec.SetField(entity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := eu.mutation.DeletedAt(); ok { + _spec.SetField(entity.FieldDeletedAt, field.TypeTime, value) + } + if eu.mutation.DeletedAtCleared() { + _spec.ClearField(entity.FieldDeletedAt, field.TypeTime) + } + if value, ok := eu.mutation.GetType(); ok { + _spec.SetField(entity.FieldType, field.TypeInt, value) + } + if value, ok := eu.mutation.AddedType(); ok { + _spec.AddField(entity.FieldType, field.TypeInt, value) + } + if value, ok := eu.mutation.Source(); ok { + _spec.SetField(entity.FieldSource, field.TypeString, value) + } + if value, ok := eu.mutation.Size(); ok { + _spec.SetField(entity.FieldSize, field.TypeInt64, value) + } + if value, ok := eu.mutation.AddedSize(); ok { + _spec.AddField(entity.FieldSize, field.TypeInt64, value) + } + if value, ok := eu.mutation.ReferenceCount(); ok { + _spec.SetField(entity.FieldReferenceCount, field.TypeInt, value) + } + if value, ok := eu.mutation.AddedReferenceCount(); ok { + _spec.AddField(entity.FieldReferenceCount, field.TypeInt, value) + } + if value, ok := eu.mutation.UploadSessionID(); ok { + _spec.SetField(entity.FieldUploadSessionID, field.TypeUUID, value) + } + if eu.mutation.UploadSessionIDCleared() { + _spec.ClearField(entity.FieldUploadSessionID, field.TypeUUID) + } + if value, ok := eu.mutation.RecycleOptions(); ok { + _spec.SetField(entity.FieldRecycleOptions, field.TypeJSON, value) + } + if eu.mutation.RecycleOptionsCleared() { + _spec.ClearField(entity.FieldRecycleOptions, field.TypeJSON) + } + if eu.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := eu.mutation.RemovedFileIDs(); len(nodes) > 0 && !eu.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := eu.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: 
sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if eu.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.UserTable, + Columns: []string{entity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := eu.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.UserTable, + Columns: []string{entity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if eu.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.StoragePolicyTable, + Columns: []string{entity.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := eu.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.StoragePolicyTable, + Columns: []string{entity.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, eu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{entity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + eu.mutation.done = true + return n, nil +} + +// EntityUpdateOne is the builder for updating a single Entity entity. +type EntityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *EntityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (euo *EntityUpdateOne) SetUpdatedAt(t time.Time) *EntityUpdateOne { + euo.mutation.SetUpdatedAt(t) + return euo +} + +// SetDeletedAt sets the "deleted_at" field. +func (euo *EntityUpdateOne) SetDeletedAt(t time.Time) *EntityUpdateOne { + euo.mutation.SetDeletedAt(t) + return euo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableDeletedAt(t *time.Time) *EntityUpdateOne { + if t != nil { + euo.SetDeletedAt(*t) + } + return euo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (euo *EntityUpdateOne) ClearDeletedAt() *EntityUpdateOne { + euo.mutation.ClearDeletedAt() + return euo +} + +// SetType sets the "type" field. +func (euo *EntityUpdateOne) SetType(i int) *EntityUpdateOne { + euo.mutation.ResetType() + euo.mutation.SetType(i) + return euo +} + +// SetNillableType sets the "type" field if the given value is not nil. 
+func (euo *EntityUpdateOne) SetNillableType(i *int) *EntityUpdateOne { + if i != nil { + euo.SetType(*i) + } + return euo +} + +// AddType adds i to the "type" field. +func (euo *EntityUpdateOne) AddType(i int) *EntityUpdateOne { + euo.mutation.AddType(i) + return euo +} + +// SetSource sets the "source" field. +func (euo *EntityUpdateOne) SetSource(s string) *EntityUpdateOne { + euo.mutation.SetSource(s) + return euo +} + +// SetNillableSource sets the "source" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableSource(s *string) *EntityUpdateOne { + if s != nil { + euo.SetSource(*s) + } + return euo +} + +// SetSize sets the "size" field. +func (euo *EntityUpdateOne) SetSize(i int64) *EntityUpdateOne { + euo.mutation.ResetSize() + euo.mutation.SetSize(i) + return euo +} + +// SetNillableSize sets the "size" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableSize(i *int64) *EntityUpdateOne { + if i != nil { + euo.SetSize(*i) + } + return euo +} + +// AddSize adds i to the "size" field. +func (euo *EntityUpdateOne) AddSize(i int64) *EntityUpdateOne { + euo.mutation.AddSize(i) + return euo +} + +// SetReferenceCount sets the "reference_count" field. +func (euo *EntityUpdateOne) SetReferenceCount(i int) *EntityUpdateOne { + euo.mutation.ResetReferenceCount() + euo.mutation.SetReferenceCount(i) + return euo +} + +// SetNillableReferenceCount sets the "reference_count" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableReferenceCount(i *int) *EntityUpdateOne { + if i != nil { + euo.SetReferenceCount(*i) + } + return euo +} + +// AddReferenceCount adds i to the "reference_count" field. +func (euo *EntityUpdateOne) AddReferenceCount(i int) *EntityUpdateOne { + euo.mutation.AddReferenceCount(i) + return euo +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (euo *EntityUpdateOne) SetStoragePolicyEntities(i int) *EntityUpdateOne { + euo.mutation.SetStoragePolicyEntities(i) + return euo +} + +// SetNillableStoragePolicyEntities sets the "storage_policy_entities" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableStoragePolicyEntities(i *int) *EntityUpdateOne { + if i != nil { + euo.SetStoragePolicyEntities(*i) + } + return euo +} + +// SetCreatedBy sets the "created_by" field. +func (euo *EntityUpdateOne) SetCreatedBy(i int) *EntityUpdateOne { + euo.mutation.SetCreatedBy(i) + return euo +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableCreatedBy(i *int) *EntityUpdateOne { + if i != nil { + euo.SetCreatedBy(*i) + } + return euo +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (euo *EntityUpdateOne) ClearCreatedBy() *EntityUpdateOne { + euo.mutation.ClearCreatedBy() + return euo +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (euo *EntityUpdateOne) SetUploadSessionID(u uuid.UUID) *EntityUpdateOne { + euo.mutation.SetUploadSessionID(u) + return euo +} + +// SetNillableUploadSessionID sets the "upload_session_id" field if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableUploadSessionID(u *uuid.UUID) *EntityUpdateOne { + if u != nil { + euo.SetUploadSessionID(*u) + } + return euo +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. 
+func (euo *EntityUpdateOne) ClearUploadSessionID() *EntityUpdateOne { + euo.mutation.ClearUploadSessionID() + return euo +} + +// SetRecycleOptions sets the "recycle_options" field. +func (euo *EntityUpdateOne) SetRecycleOptions(tro *types.EntityRecycleOption) *EntityUpdateOne { + euo.mutation.SetRecycleOptions(tro) + return euo +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. +func (euo *EntityUpdateOne) ClearRecycleOptions() *EntityUpdateOne { + euo.mutation.ClearRecycleOptions() + return euo +} + +// AddFileIDs adds the "file" edge to the File entity by IDs. +func (euo *EntityUpdateOne) AddFileIDs(ids ...int) *EntityUpdateOne { + euo.mutation.AddFileIDs(ids...) + return euo +} + +// AddFile adds the "file" edges to the File entity. +func (euo *EntityUpdateOne) AddFile(f ...*File) *EntityUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return euo.AddFileIDs(ids...) +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (euo *EntityUpdateOne) SetUserID(id int) *EntityUpdateOne { + euo.mutation.SetUserID(id) + return euo +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (euo *EntityUpdateOne) SetNillableUserID(id *int) *EntityUpdateOne { + if id != nil { + euo = euo.SetUserID(*id) + } + return euo +} + +// SetUser sets the "user" edge to the User entity. +func (euo *EntityUpdateOne) SetUser(u *User) *EntityUpdateOne { + return euo.SetUserID(u.ID) +} + +// SetStoragePolicyID sets the "storage_policy" edge to the StoragePolicy entity by ID. +func (euo *EntityUpdateOne) SetStoragePolicyID(id int) *EntityUpdateOne { + euo.mutation.SetStoragePolicyID(id) + return euo +} + +// SetStoragePolicy sets the "storage_policy" edge to the StoragePolicy entity. +func (euo *EntityUpdateOne) SetStoragePolicy(s *StoragePolicy) *EntityUpdateOne { + return euo.SetStoragePolicyID(s.ID) +} + +// Mutation returns the EntityMutation object of the builder. +func (euo *EntityUpdateOne) Mutation() *EntityMutation { + return euo.mutation +} + +// ClearFile clears all "file" edges to the File entity. +func (euo *EntityUpdateOne) ClearFile() *EntityUpdateOne { + euo.mutation.ClearFile() + return euo +} + +// RemoveFileIDs removes the "file" edge to File entities by IDs. +func (euo *EntityUpdateOne) RemoveFileIDs(ids ...int) *EntityUpdateOne { + euo.mutation.RemoveFileIDs(ids...) + return euo +} + +// RemoveFile removes "file" edges to File entities. +func (euo *EntityUpdateOne) RemoveFile(f ...*File) *EntityUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return euo.RemoveFileIDs(ids...) +} + +// ClearUser clears the "user" edge to the User entity. +func (euo *EntityUpdateOne) ClearUser() *EntityUpdateOne { + euo.mutation.ClearUser() + return euo +} + +// ClearStoragePolicy clears the "storage_policy" edge to the StoragePolicy entity. +func (euo *EntityUpdateOne) ClearStoragePolicy() *EntityUpdateOne { + euo.mutation.ClearStoragePolicy() + return euo +} + +// Where appends a list predicates to the EntityUpdate builder. +func (euo *EntityUpdateOne) Where(ps ...predicate.Entity) *EntityUpdateOne { + euo.mutation.Where(ps...) + return euo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (euo *EntityUpdateOne) Select(field string, fields ...string) *EntityUpdateOne { + euo.fields = append([]string{field}, fields...) 
+ return euo +} + +// Save executes the query and returns the updated Entity entity. +func (euo *EntityUpdateOne) Save(ctx context.Context) (*Entity, error) { + if err := euo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, euo.sqlSave, euo.mutation, euo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (euo *EntityUpdateOne) SaveX(ctx context.Context) *Entity { + node, err := euo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (euo *EntityUpdateOne) Exec(ctx context.Context) error { + _, err := euo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (euo *EntityUpdateOne) ExecX(ctx context.Context) { + if err := euo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (euo *EntityUpdateOne) defaults() error { + if _, ok := euo.mutation.UpdatedAt(); !ok { + if entity.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized entity.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := entity.UpdateDefaultUpdatedAt() + euo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (euo *EntityUpdateOne) check() error { + if _, ok := euo.mutation.StoragePolicyID(); euo.mutation.StoragePolicyCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Entity.storage_policy"`) + } + return nil +} + +func (euo *EntityUpdateOne) sqlSave(ctx context.Context) (_node *Entity, err error) { + if err := euo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(entity.Table, entity.Columns, sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt)) + id, ok := euo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Entity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := euo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, entity.FieldID) + for _, f := range fields { + if !entity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != entity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := euo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := euo.mutation.UpdatedAt(); ok { + _spec.SetField(entity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := euo.mutation.DeletedAt(); ok { + _spec.SetField(entity.FieldDeletedAt, field.TypeTime, value) + } + if euo.mutation.DeletedAtCleared() { + _spec.ClearField(entity.FieldDeletedAt, field.TypeTime) + } + if value, ok := euo.mutation.GetType(); ok { + _spec.SetField(entity.FieldType, field.TypeInt, value) + } + if value, ok := euo.mutation.AddedType(); ok { + _spec.AddField(entity.FieldType, field.TypeInt, value) + } + if value, ok := euo.mutation.Source(); ok { + _spec.SetField(entity.FieldSource, field.TypeString, value) + } + if value, ok := euo.mutation.Size(); ok { + _spec.SetField(entity.FieldSize, field.TypeInt64, value) + } + if value, ok := euo.mutation.AddedSize(); ok { + _spec.AddField(entity.FieldSize, field.TypeInt64, value) + } + if value, ok := euo.mutation.ReferenceCount(); ok { + _spec.SetField(entity.FieldReferenceCount, field.TypeInt, value) + } + if 
value, ok := euo.mutation.AddedReferenceCount(); ok { + _spec.AddField(entity.FieldReferenceCount, field.TypeInt, value) + } + if value, ok := euo.mutation.UploadSessionID(); ok { + _spec.SetField(entity.FieldUploadSessionID, field.TypeUUID, value) + } + if euo.mutation.UploadSessionIDCleared() { + _spec.ClearField(entity.FieldUploadSessionID, field.TypeUUID) + } + if value, ok := euo.mutation.RecycleOptions(); ok { + _spec.SetField(entity.FieldRecycleOptions, field.TypeJSON, value) + } + if euo.mutation.RecycleOptionsCleared() { + _spec.ClearField(entity.FieldRecycleOptions, field.TypeJSON) + } + if euo.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := euo.mutation.RemovedFileIDs(); len(nodes) > 0 && !euo.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := euo.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: entity.FileTable, + Columns: entity.FilePrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if euo.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.UserTable, + Columns: []string{entity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := euo.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.UserTable, + Columns: []string{entity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if euo.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.StoragePolicyTable, + Columns: []string{entity.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := euo.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: entity.StoragePolicyTable, + Columns: []string{entity.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Entity{config: euo.config} + 
_spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, euo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{entity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + euo.mutation.done = true + return _node, nil +} diff --git a/ent/enttest/enttest.go b/ent/enttest/enttest.go new file mode 100644 index 00000000..3e339dbd --- /dev/null +++ b/ent/enttest/enttest.go @@ -0,0 +1,84 @@ +// Code generated by ent, DO NOT EDIT. + +package enttest + +import ( + "context" + + "github.com/cloudreve/Cloudreve/v4/ent" + // required by schema hooks. + _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" + + "entgo.io/ent/dialect/sql/schema" + "github.com/cloudreve/Cloudreve/v4/ent/migrate" +) + +type ( + // TestingT is the interface that is shared between + // testing.T and testing.B and used by enttest. + TestingT interface { + FailNow() + Error(...any) + } + + // Option configures client creation. + Option func(*options) + + options struct { + opts []ent.Option + migrateOpts []schema.MigrateOption + } +) + +// WithOptions forwards options to client creation. +func WithOptions(opts ...ent.Option) Option { + return func(o *options) { + o.opts = append(o.opts, opts...) + } +} + +// WithMigrateOptions forwards options to auto migration. +func WithMigrateOptions(opts ...schema.MigrateOption) Option { + return func(o *options) { + o.migrateOpts = append(o.migrateOpts, opts...) + } +} + +func newOptions(opts []Option) *options { + o := &options{} + for _, opt := range opts { + opt(o) + } + return o +} + +// Open calls ent.Open and auto-run migration. +func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { + o := newOptions(opts) + c, err := ent.Open(driverName, dataSourceName, o.opts...) + if err != nil { + t.Error(err) + t.FailNow() + } + migrateSchema(t, c, o) + return c +} + +// NewClient calls ent.NewClient and auto-run migration. +func NewClient(t TestingT, opts ...Option) *ent.Client { + o := newOptions(opts) + c := ent.NewClient(o.opts...) + migrateSchema(t, c, o) + return c +} +func migrateSchema(t TestingT, c *ent.Client, o *options) { + tables, err := schema.CopyTables(migrate.Tables) + if err != nil { + t.Error(err) + t.FailNow() + } + if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil { + t.Error(err) + t.FailNow() + } +} diff --git a/ent/file.go b/ent/file.go new file mode 100644 index 00000000..e92ede44 --- /dev/null +++ b/ent/file.go @@ -0,0 +1,438 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// File is the model entity for the File schema. +type File struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Type holds the value of the "type" field. 
+ Type int `json:"type,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // OwnerID holds the value of the "owner_id" field. + OwnerID int `json:"owner_id,omitempty"` + // Size holds the value of the "size" field. + Size int64 `json:"size,omitempty"` + // PrimaryEntity holds the value of the "primary_entity" field. + PrimaryEntity int `json:"primary_entity,omitempty"` + // FileChildren holds the value of the "file_children" field. + FileChildren int `json:"file_children,omitempty"` + // IsSymbolic holds the value of the "is_symbolic" field. + IsSymbolic bool `json:"is_symbolic,omitempty"` + // Props holds the value of the "props" field. + Props *types.FileProps `json:"props,omitempty"` + // StoragePolicyFiles holds the value of the "storage_policy_files" field. + StoragePolicyFiles int `json:"storage_policy_files,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the FileQuery when eager-loading is set. + Edges FileEdges `json:"edges"` + selectValues sql.SelectValues +} + +// FileEdges holds the relations/edges for other nodes in the graph. +type FileEdges struct { + // Owner holds the value of the owner edge. + Owner *User `json:"owner,omitempty"` + // StoragePolicies holds the value of the storage_policies edge. + StoragePolicies *StoragePolicy `json:"storage_policies,omitempty"` + // Parent holds the value of the parent edge. + Parent *File `json:"parent,omitempty"` + // Children holds the value of the children edge. + Children []*File `json:"children,omitempty"` + // Metadata holds the value of the metadata edge. + Metadata []*Metadata `json:"metadata,omitempty"` + // Entities holds the value of the entities edge. + Entities []*Entity `json:"entities,omitempty"` + // Shares holds the value of the shares edge. + Shares []*Share `json:"shares,omitempty"` + // DirectLinks holds the value of the direct_links edge. + DirectLinks []*DirectLink `json:"direct_links,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [8]bool +} + +// OwnerOrErr returns the Owner value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e FileEdges) OwnerOrErr() (*User, error) { + if e.loadedTypes[0] { + if e.Owner == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.Owner, nil + } + return nil, &NotLoadedError{edge: "owner"} +} + +// StoragePoliciesOrErr returns the StoragePolicies value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e FileEdges) StoragePoliciesOrErr() (*StoragePolicy, error) { + if e.loadedTypes[1] { + if e.StoragePolicies == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: storagepolicy.Label} + } + return e.StoragePolicies, nil + } + return nil, &NotLoadedError{edge: "storage_policies"} +} + +// ParentOrErr returns the Parent value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e FileEdges) ParentOrErr() (*File, error) { + if e.loadedTypes[2] { + if e.Parent == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: file.Label} + } + return e.Parent, nil + } + return nil, &NotLoadedError{edge: "parent"} +} + +// ChildrenOrErr returns the Children value or an error if the edge +// was not loaded in eager-loading. 
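As an aside on the `ent/enttest` helper introduced just above: it wires `ent.Open` together with auto-migration for tests and reports failures through its `TestingT` interface. The sketch below is illustrative only and not taken from the Cloudreve codebase; the `inventory_test` package name, the test name, the `mattn/go-sqlite3` driver import, and the in-memory DSN are all assumptions.

```go
package inventory_test

import (
	"context"
	"testing"

	_ "github.com/mattn/go-sqlite3" // assumed driver; any database/sql driver registered as "sqlite3" works

	"github.com/cloudreve/Cloudreve/v4/ent/enttest"
)

func TestFileSchema(t *testing.T) {
	// Open an in-memory SQLite database; enttest runs the auto-migration and
	// fails the test through the TestingT interface if either step errors.
	client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
	defer client.Close()

	// Trivial query against the freshly migrated, empty schema.
	if n, err := client.File.Query().Count(context.Background()); err != nil || n != 0 {
		t.Fatalf("unexpected state: count=%d err=%v", n, err)
	}
}
```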
+func (e FileEdges) ChildrenOrErr() ([]*File, error) { + if e.loadedTypes[3] { + return e.Children, nil + } + return nil, &NotLoadedError{edge: "children"} +} + +// MetadataOrErr returns the Metadata value or an error if the edge +// was not loaded in eager-loading. +func (e FileEdges) MetadataOrErr() ([]*Metadata, error) { + if e.loadedTypes[4] { + return e.Metadata, nil + } + return nil, &NotLoadedError{edge: "metadata"} +} + +// EntitiesOrErr returns the Entities value or an error if the edge +// was not loaded in eager-loading. +func (e FileEdges) EntitiesOrErr() ([]*Entity, error) { + if e.loadedTypes[5] { + return e.Entities, nil + } + return nil, &NotLoadedError{edge: "entities"} +} + +// SharesOrErr returns the Shares value or an error if the edge +// was not loaded in eager-loading. +func (e FileEdges) SharesOrErr() ([]*Share, error) { + if e.loadedTypes[6] { + return e.Shares, nil + } + return nil, &NotLoadedError{edge: "shares"} +} + +// DirectLinksOrErr returns the DirectLinks value or an error if the edge +// was not loaded in eager-loading. +func (e FileEdges) DirectLinksOrErr() ([]*DirectLink, error) { + if e.loadedTypes[7] { + return e.DirectLinks, nil + } + return nil, &NotLoadedError{edge: "direct_links"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*File) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case file.FieldProps: + values[i] = new([]byte) + case file.FieldIsSymbolic: + values[i] = new(sql.NullBool) + case file.FieldID, file.FieldType, file.FieldOwnerID, file.FieldSize, file.FieldPrimaryEntity, file.FieldFileChildren, file.FieldStoragePolicyFiles: + values[i] = new(sql.NullInt64) + case file.FieldName: + values[i] = new(sql.NullString) + case file.FieldCreatedAt, file.FieldUpdatedAt, file.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the File fields. 
+func (f *File) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case file.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + f.ID = int(value.Int64) + case file.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + f.CreatedAt = value.Time + } + case file.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + f.UpdatedAt = value.Time + } + case file.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + f.DeletedAt = new(time.Time) + *f.DeletedAt = value.Time + } + case file.FieldType: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + f.Type = int(value.Int64) + } + case file.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + f.Name = value.String + } + case file.FieldOwnerID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field owner_id", values[i]) + } else if value.Valid { + f.OwnerID = int(value.Int64) + } + case file.FieldSize: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field size", values[i]) + } else if value.Valid { + f.Size = value.Int64 + } + case file.FieldPrimaryEntity: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field primary_entity", values[i]) + } else if value.Valid { + f.PrimaryEntity = int(value.Int64) + } + case file.FieldFileChildren: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field file_children", values[i]) + } else if value.Valid { + f.FileChildren = int(value.Int64) + } + case file.FieldIsSymbolic: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_symbolic", values[i]) + } else if value.Valid { + f.IsSymbolic = value.Bool + } + case file.FieldProps: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field props", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &f.Props); err != nil { + return fmt.Errorf("unmarshal field props: %w", err) + } + } + case file.FieldStoragePolicyFiles: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field storage_policy_files", values[i]) + } else if value.Valid { + f.StoragePolicyFiles = int(value.Int64) + } + default: + f.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the File. +// This includes values selected through modifiers, order, etc. +func (f *File) Value(name string) (ent.Value, error) { + return f.selectValues.Get(name) +} + +// QueryOwner queries the "owner" edge of the File entity. 
+func (f *File) QueryOwner() *UserQuery { + return NewFileClient(f.config).QueryOwner(f) +} + +// QueryStoragePolicies queries the "storage_policies" edge of the File entity. +func (f *File) QueryStoragePolicies() *StoragePolicyQuery { + return NewFileClient(f.config).QueryStoragePolicies(f) +} + +// QueryParent queries the "parent" edge of the File entity. +func (f *File) QueryParent() *FileQuery { + return NewFileClient(f.config).QueryParent(f) +} + +// QueryChildren queries the "children" edge of the File entity. +func (f *File) QueryChildren() *FileQuery { + return NewFileClient(f.config).QueryChildren(f) +} + +// QueryMetadata queries the "metadata" edge of the File entity. +func (f *File) QueryMetadata() *MetadataQuery { + return NewFileClient(f.config).QueryMetadata(f) +} + +// QueryEntities queries the "entities" edge of the File entity. +func (f *File) QueryEntities() *EntityQuery { + return NewFileClient(f.config).QueryEntities(f) +} + +// QueryShares queries the "shares" edge of the File entity. +func (f *File) QueryShares() *ShareQuery { + return NewFileClient(f.config).QueryShares(f) +} + +// QueryDirectLinks queries the "direct_links" edge of the File entity. +func (f *File) QueryDirectLinks() *DirectLinkQuery { + return NewFileClient(f.config).QueryDirectLinks(f) +} + +// Update returns a builder for updating this File. +// Note that you need to call File.Unwrap() before calling this method if this File +// was returned from a transaction, and the transaction was committed or rolled back. +func (f *File) Update() *FileUpdateOne { + return NewFileClient(f.config).UpdateOne(f) +} + +// Unwrap unwraps the File entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (f *File) Unwrap() *File { + _tx, ok := f.config.driver.(*txDriver) + if !ok { + panic("ent: File is not a transactional entity") + } + f.config.driver = _tx.drv + return f +} + +// String implements the fmt.Stringer. 
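As an aside on the `Query*` helpers above: each loaded `File` can start edge traversals through its own client. A minimal sketch of a function body follows; `client` (an `*ent.Client`), `ctx`, `id`, and the `fmt` import are assumed to exist in the surrounding code.

```go
// Load one file, then walk its edges with the generated traversal helpers.
f, err := client.File.Get(ctx, id)
if err != nil {
	return err
}
owner, err := f.QueryOwner().Only(ctx) // M2O edge: exactly one owner expected
if err != nil {
	return err
}
children, err := f.QueryChildren().All(ctx) // O2M edge: zero or more children
if err != nil {
	return err
}
fmt.Printf("%q is owned by user %d and has %d children\n", f.Name, owner.ID, len(children))
```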
+func (f *File) String() string { + var builder strings.Builder + builder.WriteString("File(") + builder.WriteString(fmt.Sprintf("id=%v, ", f.ID)) + builder.WriteString("created_at=") + builder.WriteString(f.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(f.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := f.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("type=") + builder.WriteString(fmt.Sprintf("%v", f.Type)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(f.Name) + builder.WriteString(", ") + builder.WriteString("owner_id=") + builder.WriteString(fmt.Sprintf("%v", f.OwnerID)) + builder.WriteString(", ") + builder.WriteString("size=") + builder.WriteString(fmt.Sprintf("%v", f.Size)) + builder.WriteString(", ") + builder.WriteString("primary_entity=") + builder.WriteString(fmt.Sprintf("%v", f.PrimaryEntity)) + builder.WriteString(", ") + builder.WriteString("file_children=") + builder.WriteString(fmt.Sprintf("%v", f.FileChildren)) + builder.WriteString(", ") + builder.WriteString("is_symbolic=") + builder.WriteString(fmt.Sprintf("%v", f.IsSymbolic)) + builder.WriteString(", ") + builder.WriteString("props=") + builder.WriteString(fmt.Sprintf("%v", f.Props)) + builder.WriteString(", ") + builder.WriteString("storage_policy_files=") + builder.WriteString(fmt.Sprintf("%v", f.StoragePolicyFiles)) + builder.WriteByte(')') + return builder.String() +} + +// SetOwner manually set the edge as loaded state. +func (e *File) SetOwner(v *User) { + e.Edges.Owner = v + e.Edges.loadedTypes[0] = true +} + +// SetStoragePolicies manually set the edge as loaded state. +func (e *File) SetStoragePolicies(v *StoragePolicy) { + e.Edges.StoragePolicies = v + e.Edges.loadedTypes[1] = true +} + +// SetParent manually set the edge as loaded state. +func (e *File) SetParent(v *File) { + e.Edges.Parent = v + e.Edges.loadedTypes[2] = true +} + +// SetChildren manually set the edge as loaded state. +func (e *File) SetChildren(v []*File) { + e.Edges.Children = v + e.Edges.loadedTypes[3] = true +} + +// SetMetadata manually set the edge as loaded state. +func (e *File) SetMetadata(v []*Metadata) { + e.Edges.Metadata = v + e.Edges.loadedTypes[4] = true +} + +// SetEntities manually set the edge as loaded state. +func (e *File) SetEntities(v []*Entity) { + e.Edges.Entities = v + e.Edges.loadedTypes[5] = true +} + +// SetShares manually set the edge as loaded state. +func (e *File) SetShares(v []*Share) { + e.Edges.Shares = v + e.Edges.loadedTypes[6] = true +} + +// SetDirectLinks manually set the edge as loaded state. +func (e *File) SetDirectLinks(v []*DirectLink) { + e.Edges.DirectLinks = v + e.Edges.loadedTypes[7] = true +} + +// Files is a parsable slice of File. +type Files []*File diff --git a/ent/file/file.go b/ent/file/file.go new file mode 100644 index 00000000..6adc3dae --- /dev/null +++ b/ent/file/file.go @@ -0,0 +1,371 @@ +// Code generated by ent, DO NOT EDIT. + +package file + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the file type in the database. + Label = "file" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. 
+ FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldOwnerID holds the string denoting the owner_id field in the database. + FieldOwnerID = "owner_id" + // FieldSize holds the string denoting the size field in the database. + FieldSize = "size" + // FieldPrimaryEntity holds the string denoting the primary_entity field in the database. + FieldPrimaryEntity = "primary_entity" + // FieldFileChildren holds the string denoting the file_children field in the database. + FieldFileChildren = "file_children" + // FieldIsSymbolic holds the string denoting the is_symbolic field in the database. + FieldIsSymbolic = "is_symbolic" + // FieldProps holds the string denoting the props field in the database. + FieldProps = "props" + // FieldStoragePolicyFiles holds the string denoting the storage_policy_files field in the database. + FieldStoragePolicyFiles = "storage_policy_files" + // EdgeOwner holds the string denoting the owner edge name in mutations. + EdgeOwner = "owner" + // EdgeStoragePolicies holds the string denoting the storage_policies edge name in mutations. + EdgeStoragePolicies = "storage_policies" + // EdgeParent holds the string denoting the parent edge name in mutations. + EdgeParent = "parent" + // EdgeChildren holds the string denoting the children edge name in mutations. + EdgeChildren = "children" + // EdgeMetadata holds the string denoting the metadata edge name in mutations. + EdgeMetadata = "metadata" + // EdgeEntities holds the string denoting the entities edge name in mutations. + EdgeEntities = "entities" + // EdgeShares holds the string denoting the shares edge name in mutations. + EdgeShares = "shares" + // EdgeDirectLinks holds the string denoting the direct_links edge name in mutations. + EdgeDirectLinks = "direct_links" + // Table holds the table name of the file in the database. + Table = "files" + // OwnerTable is the table that holds the owner relation/edge. + OwnerTable = "files" + // OwnerInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + OwnerInverseTable = "users" + // OwnerColumn is the table column denoting the owner relation/edge. + OwnerColumn = "owner_id" + // StoragePoliciesTable is the table that holds the storage_policies relation/edge. + StoragePoliciesTable = "files" + // StoragePoliciesInverseTable is the table name for the StoragePolicy entity. + // It exists in this package in order to avoid circular dependency with the "storagepolicy" package. + StoragePoliciesInverseTable = "storage_policies" + // StoragePoliciesColumn is the table column denoting the storage_policies relation/edge. + StoragePoliciesColumn = "storage_policy_files" + // ParentTable is the table that holds the parent relation/edge. + ParentTable = "files" + // ParentColumn is the table column denoting the parent relation/edge. + ParentColumn = "file_children" + // ChildrenTable is the table that holds the children relation/edge. + ChildrenTable = "files" + // ChildrenColumn is the table column denoting the children relation/edge. 
+ ChildrenColumn = "file_children" + // MetadataTable is the table that holds the metadata relation/edge. + MetadataTable = "metadata" + // MetadataInverseTable is the table name for the Metadata entity. + // It exists in this package in order to avoid circular dependency with the "metadata" package. + MetadataInverseTable = "metadata" + // MetadataColumn is the table column denoting the metadata relation/edge. + MetadataColumn = "file_id" + // EntitiesTable is the table that holds the entities relation/edge. The primary key declared below. + EntitiesTable = "file_entities" + // EntitiesInverseTable is the table name for the Entity entity. + // It exists in this package in order to avoid circular dependency with the "entity" package. + EntitiesInverseTable = "entities" + // SharesTable is the table that holds the shares relation/edge. + SharesTable = "shares" + // SharesInverseTable is the table name for the Share entity. + // It exists in this package in order to avoid circular dependency with the "share" package. + SharesInverseTable = "shares" + // SharesColumn is the table column denoting the shares relation/edge. + SharesColumn = "file_shares" + // DirectLinksTable is the table that holds the direct_links relation/edge. + DirectLinksTable = "direct_links" + // DirectLinksInverseTable is the table name for the DirectLink entity. + // It exists in this package in order to avoid circular dependency with the "directlink" package. + DirectLinksInverseTable = "direct_links" + // DirectLinksColumn is the table column denoting the direct_links relation/edge. + DirectLinksColumn = "file_id" +) + +// Columns holds all SQL columns for file fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldType, + FieldName, + FieldOwnerID, + FieldSize, + FieldPrimaryEntity, + FieldFileChildren, + FieldIsSymbolic, + FieldProps, + FieldStoragePolicyFiles, +} + +var ( + // EntitiesPrimaryKey and EntitiesColumn2 are the table columns denoting the + // primary key for the entities relation (M2M). + EntitiesPrimaryKey = []string{"file_id", "entity_id"} +) + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultSize holds the default value on creation for the "size" field. + DefaultSize int64 + // DefaultIsSymbolic holds the default value on creation for the "is_symbolic" field. + DefaultIsSymbolic bool +) + +// OrderOption defines the ordering options for the File queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByOwnerID orders the results by the owner_id field. +func ByOwnerID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOwnerID, opts...).ToFunc() +} + +// BySize orders the results by the size field. +func BySize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSize, opts...).ToFunc() +} + +// ByPrimaryEntity orders the results by the primary_entity field. +func ByPrimaryEntity(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrimaryEntity, opts...).ToFunc() +} + +// ByFileChildren orders the results by the file_children field. +func ByFileChildren(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFileChildren, opts...).ToFunc() +} + +// ByIsSymbolic orders the results by the is_symbolic field. +func ByIsSymbolic(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsSymbolic, opts...).ToFunc() +} + +// ByStoragePolicyFiles orders the results by the storage_policy_files field. +func ByStoragePolicyFiles(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoragePolicyFiles, opts...).ToFunc() +} + +// ByOwnerField orders the results by owner field. +func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...)) + } +} + +// ByStoragePoliciesField orders the results by storage_policies field. +func ByStoragePoliciesField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStoragePoliciesStep(), sql.OrderByField(field, opts...)) + } +} + +// ByParentField orders the results by parent field. +func ByParentField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newParentStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChildrenCount orders the results by children count. +func ByChildrenCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChildrenStep(), opts...) + } +} + +// ByChildren orders the results by children terms. +func ByChildren(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChildrenStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByMetadataCount orders the results by metadata count. 
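As an aside on the `OrderOption` constructors above: they compose directly with `FileQuery.Order`. A small sketch, again assuming an existing `client` and `ctx`, that lists the newest files first and breaks ties by child count:

```go
// Requires "entgo.io/ent/dialect/sql" for the sql.OrderDesc option.
files, err := client.File.Query().
	Order(
		file.ByCreatedAt(sql.OrderDesc()), // newest first
		file.ByChildrenCount(),            // tie-break by number of children
	).
	Limit(20).
	All(ctx)
```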
+func ByMetadataCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMetadataStep(), opts...) + } +} + +// ByMetadata orders the results by metadata terms. +func ByMetadata(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMetadataStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEntitiesCount orders the results by entities count. +func ByEntitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEntitiesStep(), opts...) + } +} + +// ByEntities orders the results by entities terms. +func ByEntities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEntitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// BySharesCount orders the results by shares count. +func BySharesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSharesStep(), opts...) + } +} + +// ByShares orders the results by shares terms. +func ByShares(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSharesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByDirectLinksCount orders the results by direct_links count. +func ByDirectLinksCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDirectLinksStep(), opts...) + } +} + +// ByDirectLinks orders the results by direct_links terms. +func ByDirectLinks(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDirectLinksStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newOwnerStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(OwnerInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) +} +func newStoragePoliciesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StoragePoliciesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePoliciesTable, StoragePoliciesColumn), + ) +} +func newParentStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ParentTable, ParentColumn), + ) +} +func newChildrenStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChildrenTable, ChildrenColumn), + ) +} +func newMetadataStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MetadataInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, MetadataTable, MetadataColumn), + ) +} +func newEntitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EntitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, EntitiesTable, EntitiesPrimaryKey...), + ) +} +func newSharesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SharesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SharesTable, SharesColumn), + ) +} +func newDirectLinksStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DirectLinksInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DirectLinksTable, DirectLinksColumn), + ) +} diff --git a/ent/file/where.go b/ent/file/where.go new file mode 100644 index 00000000..f2a08bbf --- /dev/null +++ b/ent/file/where.go @@ -0,0 +1,735 @@ +// Code generated by ent, DO NOT EDIT. + +package file + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.File { + return predicate.File(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.File { + return predicate.File(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.File { + return predicate.File(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.File { + return predicate.File(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.File { + return predicate.File(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.File { + return predicate.File(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
+func CreatedAt(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Type applies equality check predicate on the "type" field. It's identical to TypeEQ. +func Type(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldType, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.File { + return predicate.File(sql.FieldEQ(FieldName, v)) +} + +// OwnerID applies equality check predicate on the "owner_id" field. It's identical to OwnerIDEQ. +func OwnerID(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldOwnerID, v)) +} + +// Size applies equality check predicate on the "size" field. It's identical to SizeEQ. +func Size(v int64) predicate.File { + return predicate.File(sql.FieldEQ(FieldSize, v)) +} + +// PrimaryEntity applies equality check predicate on the "primary_entity" field. It's identical to PrimaryEntityEQ. +func PrimaryEntity(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldPrimaryEntity, v)) +} + +// FileChildren applies equality check predicate on the "file_children" field. It's identical to FileChildrenEQ. +func FileChildren(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldFileChildren, v)) +} + +// IsSymbolic applies equality check predicate on the "is_symbolic" field. It's identical to IsSymbolicEQ. +func IsSymbolic(v bool) predicate.File { + return predicate.File(sql.FieldEQ(FieldIsSymbolic, v)) +} + +// StoragePolicyFiles applies equality check predicate on the "storage_policy_files" field. It's identical to StoragePolicyFilesEQ. +func StoragePolicyFiles(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldStoragePolicyFiles, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.File { + return predicate.File(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.File { + return predicate.File(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.File { + return predicate.File(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. 
+func CreatedAtLTE(v time.Time) predicate.File { + return predicate.File(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.File { + return predicate.File(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.File { + return predicate.File(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.File { + return predicate.File(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.File { + return predicate.File(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.File { + return predicate.File(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.File { + return predicate.File(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.File { + return predicate.File(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.File { + return predicate.File(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.File { + return predicate.File(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.File { + return predicate.File(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.File { + return predicate.File(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.File { + return predicate.File(sql.FieldNotNull(FieldDeletedAt)) +} + +// TypeEQ applies the EQ predicate on the "type" field. 
+func TypeEQ(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldType, vs...)) +} + +// TypeGT applies the GT predicate on the "type" field. +func TypeGT(v int) predicate.File { + return predicate.File(sql.FieldGT(FieldType, v)) +} + +// TypeGTE applies the GTE predicate on the "type" field. +func TypeGTE(v int) predicate.File { + return predicate.File(sql.FieldGTE(FieldType, v)) +} + +// TypeLT applies the LT predicate on the "type" field. +func TypeLT(v int) predicate.File { + return predicate.File(sql.FieldLT(FieldType, v)) +} + +// TypeLTE applies the LTE predicate on the "type" field. +func TypeLTE(v int) predicate.File { + return predicate.File(sql.FieldLTE(FieldType, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.File { + return predicate.File(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.File { + return predicate.File(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.File { + return predicate.File(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.File { + return predicate.File(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.File { + return predicate.File(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.File { + return predicate.File(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.File { + return predicate.File(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.File { + return predicate.File(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.File { + return predicate.File(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.File { + return predicate.File(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.File { + return predicate.File(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.File { + return predicate.File(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.File { + return predicate.File(sql.FieldContainsFold(FieldName, v)) +} + +// OwnerIDEQ applies the EQ predicate on the "owner_id" field. 
+func OwnerIDEQ(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldOwnerID, v)) +} + +// OwnerIDNEQ applies the NEQ predicate on the "owner_id" field. +func OwnerIDNEQ(v int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldOwnerID, v)) +} + +// OwnerIDIn applies the In predicate on the "owner_id" field. +func OwnerIDIn(vs ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldOwnerID, vs...)) +} + +// OwnerIDNotIn applies the NotIn predicate on the "owner_id" field. +func OwnerIDNotIn(vs ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldOwnerID, vs...)) +} + +// SizeEQ applies the EQ predicate on the "size" field. +func SizeEQ(v int64) predicate.File { + return predicate.File(sql.FieldEQ(FieldSize, v)) +} + +// SizeNEQ applies the NEQ predicate on the "size" field. +func SizeNEQ(v int64) predicate.File { + return predicate.File(sql.FieldNEQ(FieldSize, v)) +} + +// SizeIn applies the In predicate on the "size" field. +func SizeIn(vs ...int64) predicate.File { + return predicate.File(sql.FieldIn(FieldSize, vs...)) +} + +// SizeNotIn applies the NotIn predicate on the "size" field. +func SizeNotIn(vs ...int64) predicate.File { + return predicate.File(sql.FieldNotIn(FieldSize, vs...)) +} + +// SizeGT applies the GT predicate on the "size" field. +func SizeGT(v int64) predicate.File { + return predicate.File(sql.FieldGT(FieldSize, v)) +} + +// SizeGTE applies the GTE predicate on the "size" field. +func SizeGTE(v int64) predicate.File { + return predicate.File(sql.FieldGTE(FieldSize, v)) +} + +// SizeLT applies the LT predicate on the "size" field. +func SizeLT(v int64) predicate.File { + return predicate.File(sql.FieldLT(FieldSize, v)) +} + +// SizeLTE applies the LTE predicate on the "size" field. +func SizeLTE(v int64) predicate.File { + return predicate.File(sql.FieldLTE(FieldSize, v)) +} + +// PrimaryEntityEQ applies the EQ predicate on the "primary_entity" field. +func PrimaryEntityEQ(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldPrimaryEntity, v)) +} + +// PrimaryEntityNEQ applies the NEQ predicate on the "primary_entity" field. +func PrimaryEntityNEQ(v int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldPrimaryEntity, v)) +} + +// PrimaryEntityIn applies the In predicate on the "primary_entity" field. +func PrimaryEntityIn(vs ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldPrimaryEntity, vs...)) +} + +// PrimaryEntityNotIn applies the NotIn predicate on the "primary_entity" field. +func PrimaryEntityNotIn(vs ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldPrimaryEntity, vs...)) +} + +// PrimaryEntityGT applies the GT predicate on the "primary_entity" field. +func PrimaryEntityGT(v int) predicate.File { + return predicate.File(sql.FieldGT(FieldPrimaryEntity, v)) +} + +// PrimaryEntityGTE applies the GTE predicate on the "primary_entity" field. +func PrimaryEntityGTE(v int) predicate.File { + return predicate.File(sql.FieldGTE(FieldPrimaryEntity, v)) +} + +// PrimaryEntityLT applies the LT predicate on the "primary_entity" field. +func PrimaryEntityLT(v int) predicate.File { + return predicate.File(sql.FieldLT(FieldPrimaryEntity, v)) +} + +// PrimaryEntityLTE applies the LTE predicate on the "primary_entity" field. +func PrimaryEntityLTE(v int) predicate.File { + return predicate.File(sql.FieldLTE(FieldPrimaryEntity, v)) +} + +// PrimaryEntityIsNil applies the IsNil predicate on the "primary_entity" field. 
+func PrimaryEntityIsNil() predicate.File { + return predicate.File(sql.FieldIsNull(FieldPrimaryEntity)) +} + +// PrimaryEntityNotNil applies the NotNil predicate on the "primary_entity" field. +func PrimaryEntityNotNil() predicate.File { + return predicate.File(sql.FieldNotNull(FieldPrimaryEntity)) +} + +// FileChildrenEQ applies the EQ predicate on the "file_children" field. +func FileChildrenEQ(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldFileChildren, v)) +} + +// FileChildrenNEQ applies the NEQ predicate on the "file_children" field. +func FileChildrenNEQ(v int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldFileChildren, v)) +} + +// FileChildrenIn applies the In predicate on the "file_children" field. +func FileChildrenIn(vs ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldFileChildren, vs...)) +} + +// FileChildrenNotIn applies the NotIn predicate on the "file_children" field. +func FileChildrenNotIn(vs ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldFileChildren, vs...)) +} + +// FileChildrenIsNil applies the IsNil predicate on the "file_children" field. +func FileChildrenIsNil() predicate.File { + return predicate.File(sql.FieldIsNull(FieldFileChildren)) +} + +// FileChildrenNotNil applies the NotNil predicate on the "file_children" field. +func FileChildrenNotNil() predicate.File { + return predicate.File(sql.FieldNotNull(FieldFileChildren)) +} + +// IsSymbolicEQ applies the EQ predicate on the "is_symbolic" field. +func IsSymbolicEQ(v bool) predicate.File { + return predicate.File(sql.FieldEQ(FieldIsSymbolic, v)) +} + +// IsSymbolicNEQ applies the NEQ predicate on the "is_symbolic" field. +func IsSymbolicNEQ(v bool) predicate.File { + return predicate.File(sql.FieldNEQ(FieldIsSymbolic, v)) +} + +// PropsIsNil applies the IsNil predicate on the "props" field. +func PropsIsNil() predicate.File { + return predicate.File(sql.FieldIsNull(FieldProps)) +} + +// PropsNotNil applies the NotNil predicate on the "props" field. +func PropsNotNil() predicate.File { + return predicate.File(sql.FieldNotNull(FieldProps)) +} + +// StoragePolicyFilesEQ applies the EQ predicate on the "storage_policy_files" field. +func StoragePolicyFilesEQ(v int) predicate.File { + return predicate.File(sql.FieldEQ(FieldStoragePolicyFiles, v)) +} + +// StoragePolicyFilesNEQ applies the NEQ predicate on the "storage_policy_files" field. +func StoragePolicyFilesNEQ(v int) predicate.File { + return predicate.File(sql.FieldNEQ(FieldStoragePolicyFiles, v)) +} + +// StoragePolicyFilesIn applies the In predicate on the "storage_policy_files" field. +func StoragePolicyFilesIn(vs ...int) predicate.File { + return predicate.File(sql.FieldIn(FieldStoragePolicyFiles, vs...)) +} + +// StoragePolicyFilesNotIn applies the NotIn predicate on the "storage_policy_files" field. +func StoragePolicyFilesNotIn(vs ...int) predicate.File { + return predicate.File(sql.FieldNotIn(FieldStoragePolicyFiles, vs...)) +} + +// StoragePolicyFilesIsNil applies the IsNil predicate on the "storage_policy_files" field. +func StoragePolicyFilesIsNil() predicate.File { + return predicate.File(sql.FieldIsNull(FieldStoragePolicyFiles)) +} + +// StoragePolicyFilesNotNil applies the NotNil predicate on the "storage_policy_files" field. +func StoragePolicyFilesNotNil() predicate.File { + return predicate.File(sql.FieldNotNull(FieldStoragePolicyFiles)) +} + +// HasOwner applies the HasEdge predicate on the "owner" edge. 
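As an aside on the field predicates defined so far: they are combined with AND when passed to `Where`. A short sketch (with `client`, `ctx`, and `uid` as assumed placeholders) that finds a user's large, non-deleted video files:

```go
// Predicates passed to Where are ANDed together.
large, err := client.File.Query().
	Where(
		file.OwnerID(uid),             // equality on the owner_id field
		file.DeletedAtIsNil(),         // skip soft-deleted rows
		file.NameContainsFold(".mp4"), // case-insensitive substring match
		file.SizeGT(1<<30),            // larger than 1 GiB
	).
	All(ctx)
```

The `Has*`/`Has*With` edge predicates that follow, together with `And`/`Or`/`Not` at the end of this file, compose in the same way.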
+func HasOwner() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates). +func HasOwnerWith(preds ...predicate.User) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newOwnerStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasStoragePolicies applies the HasEdge predicate on the "storage_policies" edge. +func HasStoragePolicies() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePoliciesTable, StoragePoliciesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasStoragePoliciesWith applies the HasEdge predicate on the "storage_policies" edge with a given conditions (other predicates). +func HasStoragePoliciesWith(preds ...predicate.StoragePolicy) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newStoragePoliciesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasParent applies the HasEdge predicate on the "parent" edge. +func HasParent() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, ParentTable, ParentColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasParentWith applies the HasEdge predicate on the "parent" edge with a given conditions (other predicates). +func HasParentWith(preds ...predicate.File) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newParentStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChildren applies the HasEdge predicate on the "children" edge. +func HasChildren() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChildrenTable, ChildrenColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChildrenWith applies the HasEdge predicate on the "children" edge with a given conditions (other predicates). +func HasChildrenWith(preds ...predicate.File) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newChildrenStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasMetadata applies the HasEdge predicate on the "metadata" edge. +func HasMetadata() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, MetadataTable, MetadataColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMetadataWith applies the HasEdge predicate on the "metadata" edge with a given conditions (other predicates). 
+func HasMetadataWith(preds ...predicate.Metadata) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newMetadataStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasEntities applies the HasEdge predicate on the "entities" edge. +func HasEntities() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, EntitiesTable, EntitiesPrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasEntitiesWith applies the HasEdge predicate on the "entities" edge with a given conditions (other predicates). +func HasEntitiesWith(preds ...predicate.Entity) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newEntitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasShares applies the HasEdge predicate on the "shares" edge. +func HasShares() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SharesTable, SharesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasSharesWith applies the HasEdge predicate on the "shares" edge with a given conditions (other predicates). +func HasSharesWith(preds ...predicate.Share) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newSharesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasDirectLinks applies the HasEdge predicate on the "direct_links" edge. +func HasDirectLinks() predicate.File { + return predicate.File(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DirectLinksTable, DirectLinksColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasDirectLinksWith applies the HasEdge predicate on the "direct_links" edge with a given conditions (other predicates). +func HasDirectLinksWith(preds ...predicate.DirectLink) predicate.File { + return predicate.File(func(s *sql.Selector) { + step := newDirectLinksStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.File) predicate.File { + return predicate.File(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.File) predicate.File { + return predicate.File(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.File) predicate.File { + return predicate.File(sql.NotPredicates(p)) +} diff --git a/ent/file_create.go b/ent/file_create.go new file mode 100644 index 00000000..a0330c3b --- /dev/null +++ b/ent/file_create.go @@ -0,0 +1,1509 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// FileCreate is the builder for creating a File entity. +type FileCreate struct { + config + mutation *FileMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (fc *FileCreate) SetCreatedAt(t time.Time) *FileCreate { + fc.mutation.SetCreatedAt(t) + return fc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (fc *FileCreate) SetNillableCreatedAt(t *time.Time) *FileCreate { + if t != nil { + fc.SetCreatedAt(*t) + } + return fc +} + +// SetUpdatedAt sets the "updated_at" field. +func (fc *FileCreate) SetUpdatedAt(t time.Time) *FileCreate { + fc.mutation.SetUpdatedAt(t) + return fc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (fc *FileCreate) SetNillableUpdatedAt(t *time.Time) *FileCreate { + if t != nil { + fc.SetUpdatedAt(*t) + } + return fc +} + +// SetDeletedAt sets the "deleted_at" field. +func (fc *FileCreate) SetDeletedAt(t time.Time) *FileCreate { + fc.mutation.SetDeletedAt(t) + return fc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (fc *FileCreate) SetNillableDeletedAt(t *time.Time) *FileCreate { + if t != nil { + fc.SetDeletedAt(*t) + } + return fc +} + +// SetType sets the "type" field. +func (fc *FileCreate) SetType(i int) *FileCreate { + fc.mutation.SetType(i) + return fc +} + +// SetName sets the "name" field. +func (fc *FileCreate) SetName(s string) *FileCreate { + fc.mutation.SetName(s) + return fc +} + +// SetOwnerID sets the "owner_id" field. +func (fc *FileCreate) SetOwnerID(i int) *FileCreate { + fc.mutation.SetOwnerID(i) + return fc +} + +// SetSize sets the "size" field. +func (fc *FileCreate) SetSize(i int64) *FileCreate { + fc.mutation.SetSize(i) + return fc +} + +// SetNillableSize sets the "size" field if the given value is not nil. +func (fc *FileCreate) SetNillableSize(i *int64) *FileCreate { + if i != nil { + fc.SetSize(*i) + } + return fc +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (fc *FileCreate) SetPrimaryEntity(i int) *FileCreate { + fc.mutation.SetPrimaryEntity(i) + return fc +} + +// SetNillablePrimaryEntity sets the "primary_entity" field if the given value is not nil. +func (fc *FileCreate) SetNillablePrimaryEntity(i *int) *FileCreate { + if i != nil { + fc.SetPrimaryEntity(*i) + } + return fc +} + +// SetFileChildren sets the "file_children" field. +func (fc *FileCreate) SetFileChildren(i int) *FileCreate { + fc.mutation.SetFileChildren(i) + return fc +} + +// SetNillableFileChildren sets the "file_children" field if the given value is not nil. +func (fc *FileCreate) SetNillableFileChildren(i *int) *FileCreate { + if i != nil { + fc.SetFileChildren(*i) + } + return fc +} + +// SetIsSymbolic sets the "is_symbolic" field. 
+func (fc *FileCreate) SetIsSymbolic(b bool) *FileCreate { + fc.mutation.SetIsSymbolic(b) + return fc +} + +// SetNillableIsSymbolic sets the "is_symbolic" field if the given value is not nil. +func (fc *FileCreate) SetNillableIsSymbolic(b *bool) *FileCreate { + if b != nil { + fc.SetIsSymbolic(*b) + } + return fc +} + +// SetProps sets the "props" field. +func (fc *FileCreate) SetProps(tp *types.FileProps) *FileCreate { + fc.mutation.SetProps(tp) + return fc +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (fc *FileCreate) SetStoragePolicyFiles(i int) *FileCreate { + fc.mutation.SetStoragePolicyFiles(i) + return fc +} + +// SetNillableStoragePolicyFiles sets the "storage_policy_files" field if the given value is not nil. +func (fc *FileCreate) SetNillableStoragePolicyFiles(i *int) *FileCreate { + if i != nil { + fc.SetStoragePolicyFiles(*i) + } + return fc +} + +// SetOwner sets the "owner" edge to the User entity. +func (fc *FileCreate) SetOwner(u *User) *FileCreate { + return fc.SetOwnerID(u.ID) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (fc *FileCreate) SetStoragePoliciesID(id int) *FileCreate { + fc.mutation.SetStoragePoliciesID(id) + return fc +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. +func (fc *FileCreate) SetNillableStoragePoliciesID(id *int) *FileCreate { + if id != nil { + fc = fc.SetStoragePoliciesID(*id) + } + return fc +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. +func (fc *FileCreate) SetStoragePolicies(s *StoragePolicy) *FileCreate { + return fc.SetStoragePoliciesID(s.ID) +} + +// SetParentID sets the "parent" edge to the File entity by ID. +func (fc *FileCreate) SetParentID(id int) *FileCreate { + fc.mutation.SetParentID(id) + return fc +} + +// SetNillableParentID sets the "parent" edge to the File entity by ID if the given value is not nil. +func (fc *FileCreate) SetNillableParentID(id *int) *FileCreate { + if id != nil { + fc = fc.SetParentID(*id) + } + return fc +} + +// SetParent sets the "parent" edge to the File entity. +func (fc *FileCreate) SetParent(f *File) *FileCreate { + return fc.SetParentID(f.ID) +} + +// AddChildIDs adds the "children" edge to the File entity by IDs. +func (fc *FileCreate) AddChildIDs(ids ...int) *FileCreate { + fc.mutation.AddChildIDs(ids...) + return fc +} + +// AddChildren adds the "children" edges to the File entity. +func (fc *FileCreate) AddChildren(f ...*File) *FileCreate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return fc.AddChildIDs(ids...) +} + +// AddMetadatumIDs adds the "metadata" edge to the Metadata entity by IDs. +func (fc *FileCreate) AddMetadatumIDs(ids ...int) *FileCreate { + fc.mutation.AddMetadatumIDs(ids...) + return fc +} + +// AddMetadata adds the "metadata" edges to the Metadata entity. +func (fc *FileCreate) AddMetadata(m ...*Metadata) *FileCreate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return fc.AddMetadatumIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (fc *FileCreate) AddEntityIDs(ids ...int) *FileCreate { + fc.mutation.AddEntityIDs(ids...) + return fc +} + +// AddEntities adds the "entities" edges to the Entity entity. 
+func (fc *FileCreate) AddEntities(e ...*Entity) *FileCreate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return fc.AddEntityIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (fc *FileCreate) AddShareIDs(ids ...int) *FileCreate { + fc.mutation.AddShareIDs(ids...) + return fc +} + +// AddShares adds the "shares" edges to the Share entity. +func (fc *FileCreate) AddShares(s ...*Share) *FileCreate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return fc.AddShareIDs(ids...) +} + +// AddDirectLinkIDs adds the "direct_links" edge to the DirectLink entity by IDs. +func (fc *FileCreate) AddDirectLinkIDs(ids ...int) *FileCreate { + fc.mutation.AddDirectLinkIDs(ids...) + return fc +} + +// AddDirectLinks adds the "direct_links" edges to the DirectLink entity. +func (fc *FileCreate) AddDirectLinks(d ...*DirectLink) *FileCreate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return fc.AddDirectLinkIDs(ids...) +} + +// Mutation returns the FileMutation object of the builder. +func (fc *FileCreate) Mutation() *FileMutation { + return fc.mutation +} + +// Save creates the File in the database. +func (fc *FileCreate) Save(ctx context.Context) (*File, error) { + if err := fc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, fc.sqlSave, fc.mutation, fc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (fc *FileCreate) SaveX(ctx context.Context) *File { + v, err := fc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (fc *FileCreate) Exec(ctx context.Context) error { + _, err := fc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (fc *FileCreate) ExecX(ctx context.Context) { + if err := fc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (fc *FileCreate) defaults() error { + if _, ok := fc.mutation.CreatedAt(); !ok { + if file.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized file.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := file.DefaultCreatedAt() + fc.mutation.SetCreatedAt(v) + } + if _, ok := fc.mutation.UpdatedAt(); !ok { + if file.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized file.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := file.DefaultUpdatedAt() + fc.mutation.SetUpdatedAt(v) + } + if _, ok := fc.mutation.Size(); !ok { + v := file.DefaultSize + fc.mutation.SetSize(v) + } + if _, ok := fc.mutation.IsSymbolic(); !ok { + v := file.DefaultIsSymbolic + fc.mutation.SetIsSymbolic(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
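+//
+// Editorial note, not part of the ent-generated output: going by the checks below,
+// a minimal valid create needs "type", "name" and the "owner" edge (created_at,
+// updated_at, size and is_symbolic fall back to defaults). A hedged sketch with
+// illustrative values:
+//
+//	f, err := client.File.Create().
+//		SetType(1).
+//		SetName("example.txt").
+//		SetOwnerID(1).
+//		Save(ctx)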
+func (fc *FileCreate) check() error { + if _, ok := fc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "File.created_at"`)} + } + if _, ok := fc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "File.updated_at"`)} + } + if _, ok := fc.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "File.type"`)} + } + if _, ok := fc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "File.name"`)} + } + if _, ok := fc.mutation.OwnerID(); !ok { + return &ValidationError{Name: "owner_id", err: errors.New(`ent: missing required field "File.owner_id"`)} + } + if _, ok := fc.mutation.Size(); !ok { + return &ValidationError{Name: "size", err: errors.New(`ent: missing required field "File.size"`)} + } + if _, ok := fc.mutation.IsSymbolic(); !ok { + return &ValidationError{Name: "is_symbolic", err: errors.New(`ent: missing required field "File.is_symbolic"`)} + } + if _, ok := fc.mutation.OwnerID(); !ok { + return &ValidationError{Name: "owner", err: errors.New(`ent: missing required edge "File.owner"`)} + } + return nil +} + +func (fc *FileCreate) sqlSave(ctx context.Context) (*File, error) { + if err := fc.check(); err != nil { + return nil, err + } + _node, _spec := fc.createSpec() + if err := sqlgraph.CreateNode(ctx, fc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + fc.mutation.id = &_node.ID + fc.mutation.done = true + return _node, nil +} + +func (fc *FileCreate) createSpec() (*File, *sqlgraph.CreateSpec) { + var ( + _node = &File{config: fc.config} + _spec = sqlgraph.NewCreateSpec(file.Table, sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt)) + ) + + if id, ok := fc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = fc.conflict + if value, ok := fc.mutation.CreatedAt(); ok { + _spec.SetField(file.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := fc.mutation.UpdatedAt(); ok { + _spec.SetField(file.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := fc.mutation.DeletedAt(); ok { + _spec.SetField(file.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := fc.mutation.GetType(); ok { + _spec.SetField(file.FieldType, field.TypeInt, value) + _node.Type = value + } + if value, ok := fc.mutation.Name(); ok { + _spec.SetField(file.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := fc.mutation.Size(); ok { + _spec.SetField(file.FieldSize, field.TypeInt64, value) + _node.Size = value + } + if value, ok := fc.mutation.PrimaryEntity(); ok { + _spec.SetField(file.FieldPrimaryEntity, field.TypeInt, value) + _node.PrimaryEntity = value + } + if value, ok := fc.mutation.IsSymbolic(); ok { + _spec.SetField(file.FieldIsSymbolic, field.TypeBool, value) + _node.IsSymbolic = value + } + if value, ok := fc.mutation.Props(); ok { + _spec.SetField(file.FieldProps, field.TypeJSON, value) + _node.Props = value + } + if nodes := fc.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.OwnerTable, + Columns: []string{file.OwnerColumn}, + Bidi: false, + Target: 
&sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.OwnerID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.StoragePoliciesTable, + Columns: []string{file.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.StoragePolicyFiles = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.ParentIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.ParentTable, + Columns: []string{file.ParentColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.FileChildren = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.ChildrenIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.MetadataIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := fc.mutation.DirectLinksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON 
DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.File.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.FileUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (fc *FileCreate) OnConflict(opts ...sql.ConflictOption) *FileUpsertOne { + fc.conflict = opts + return &FileUpsertOne{ + create: fc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.File.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (fc *FileCreate) OnConflictColumns(columns ...string) *FileUpsertOne { + fc.conflict = append(fc.conflict, sql.ConflictColumns(columns...)) + return &FileUpsertOne{ + create: fc, + } +} + +type ( + // FileUpsertOne is the builder for "upsert"-ing + // one File node. + FileUpsertOne struct { + create *FileCreate + } + + // FileUpsert is the "OnConflict" setter. + FileUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *FileUpsert) SetUpdatedAt(v time.Time) *FileUpsert { + u.Set(file.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *FileUpsert) UpdateUpdatedAt() *FileUpsert { + u.SetExcluded(file.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *FileUpsert) SetDeletedAt(v time.Time) *FileUpsert { + u.Set(file.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *FileUpsert) UpdateDeletedAt() *FileUpsert { + u.SetExcluded(file.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *FileUpsert) ClearDeletedAt() *FileUpsert { + u.SetNull(file.FieldDeletedAt) + return u +} + +// SetType sets the "type" field. +func (u *FileUpsert) SetType(v int) *FileUpsert { + u.Set(file.FieldType, v) + return u +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *FileUpsert) UpdateType() *FileUpsert { + u.SetExcluded(file.FieldType) + return u +} + +// AddType adds v to the "type" field. +func (u *FileUpsert) AddType(v int) *FileUpsert { + u.Add(file.FieldType, v) + return u +} + +// SetName sets the "name" field. +func (u *FileUpsert) SetName(v string) *FileUpsert { + u.Set(file.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *FileUpsert) UpdateName() *FileUpsert { + u.SetExcluded(file.FieldName) + return u +} + +// SetOwnerID sets the "owner_id" field. +func (u *FileUpsert) SetOwnerID(v int) *FileUpsert { + u.Set(file.FieldOwnerID, v) + return u +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *FileUpsert) UpdateOwnerID() *FileUpsert { + u.SetExcluded(file.FieldOwnerID) + return u +} + +// SetSize sets the "size" field. +func (u *FileUpsert) SetSize(v int64) *FileUpsert { + u.Set(file.FieldSize, v) + return u +} + +// UpdateSize sets the "size" field to the value that was provided on create. +func (u *FileUpsert) UpdateSize() *FileUpsert { + u.SetExcluded(file.FieldSize) + return u +} + +// AddSize adds v to the "size" field. 
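+//
+// Editorial sketch, not part of the ent-generated output: one way this might be
+// combined with an upsert, assuming "name" is the conflict column; the sizes and
+// other values are illustrative:
+//
+//	err := client.File.Create().
+//		SetType(1).
+//		SetName("example.txt").
+//		SetOwnerID(1).
+//		SetSize(1024).
+//		OnConflictColumns(file.FieldName).
+//		Update(func(u *ent.FileUpsert) { u.AddSize(1024) }).
+//		Exec(ctx)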
+func (u *FileUpsert) AddSize(v int64) *FileUpsert { + u.Add(file.FieldSize, v) + return u +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (u *FileUpsert) SetPrimaryEntity(v int) *FileUpsert { + u.Set(file.FieldPrimaryEntity, v) + return u +} + +// UpdatePrimaryEntity sets the "primary_entity" field to the value that was provided on create. +func (u *FileUpsert) UpdatePrimaryEntity() *FileUpsert { + u.SetExcluded(file.FieldPrimaryEntity) + return u +} + +// AddPrimaryEntity adds v to the "primary_entity" field. +func (u *FileUpsert) AddPrimaryEntity(v int) *FileUpsert { + u.Add(file.FieldPrimaryEntity, v) + return u +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (u *FileUpsert) ClearPrimaryEntity() *FileUpsert { + u.SetNull(file.FieldPrimaryEntity) + return u +} + +// SetFileChildren sets the "file_children" field. +func (u *FileUpsert) SetFileChildren(v int) *FileUpsert { + u.Set(file.FieldFileChildren, v) + return u +} + +// UpdateFileChildren sets the "file_children" field to the value that was provided on create. +func (u *FileUpsert) UpdateFileChildren() *FileUpsert { + u.SetExcluded(file.FieldFileChildren) + return u +} + +// ClearFileChildren clears the value of the "file_children" field. +func (u *FileUpsert) ClearFileChildren() *FileUpsert { + u.SetNull(file.FieldFileChildren) + return u +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (u *FileUpsert) SetIsSymbolic(v bool) *FileUpsert { + u.Set(file.FieldIsSymbolic, v) + return u +} + +// UpdateIsSymbolic sets the "is_symbolic" field to the value that was provided on create. +func (u *FileUpsert) UpdateIsSymbolic() *FileUpsert { + u.SetExcluded(file.FieldIsSymbolic) + return u +} + +// SetProps sets the "props" field. +func (u *FileUpsert) SetProps(v *types.FileProps) *FileUpsert { + u.Set(file.FieldProps, v) + return u +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *FileUpsert) UpdateProps() *FileUpsert { + u.SetExcluded(file.FieldProps) + return u +} + +// ClearProps clears the value of the "props" field. +func (u *FileUpsert) ClearProps() *FileUpsert { + u.SetNull(file.FieldProps) + return u +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (u *FileUpsert) SetStoragePolicyFiles(v int) *FileUpsert { + u.Set(file.FieldStoragePolicyFiles, v) + return u +} + +// UpdateStoragePolicyFiles sets the "storage_policy_files" field to the value that was provided on create. +func (u *FileUpsert) UpdateStoragePolicyFiles() *FileUpsert { + u.SetExcluded(file.FieldStoragePolicyFiles) + return u +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (u *FileUpsert) ClearStoragePolicyFiles() *FileUpsert { + u.SetNull(file.FieldStoragePolicyFiles) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.File.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *FileUpsertOne) UpdateNewValues() *FileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(file.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.File.Create(). 
+// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *FileUpsertOne) Ignore() *FileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *FileUpsertOne) DoNothing() *FileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the FileCreate.OnConflict +// documentation for more info. +func (u *FileUpsertOne) Update(set func(*FileUpsert)) *FileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&FileUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *FileUpsertOne) SetUpdatedAt(v time.Time) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateUpdatedAt() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *FileUpsertOne) SetDeletedAt(v time.Time) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateDeletedAt() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *FileUpsertOne) ClearDeletedAt() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *FileUpsertOne) SetType(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetType(v) + }) +} + +// AddType adds v to the "type" field. +func (u *FileUpsertOne) AddType(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.AddType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateType() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateType() + }) +} + +// SetName sets the "name" field. +func (u *FileUpsertOne) SetName(v string) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateName() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateName() + }) +} + +// SetOwnerID sets the "owner_id" field. +func (u *FileUpsertOne) SetOwnerID(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetOwnerID(v) + }) +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateOwnerID() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateOwnerID() + }) +} + +// SetSize sets the "size" field. +func (u *FileUpsertOne) SetSize(v int64) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetSize(v) + }) +} + +// AddSize adds v to the "size" field. +func (u *FileUpsertOne) AddSize(v int64) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.AddSize(v) + }) +} + +// UpdateSize sets the "size" field to the value that was provided on create. 
+func (u *FileUpsertOne) UpdateSize() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateSize() + }) +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (u *FileUpsertOne) SetPrimaryEntity(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetPrimaryEntity(v) + }) +} + +// AddPrimaryEntity adds v to the "primary_entity" field. +func (u *FileUpsertOne) AddPrimaryEntity(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.AddPrimaryEntity(v) + }) +} + +// UpdatePrimaryEntity sets the "primary_entity" field to the value that was provided on create. +func (u *FileUpsertOne) UpdatePrimaryEntity() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdatePrimaryEntity() + }) +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (u *FileUpsertOne) ClearPrimaryEntity() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.ClearPrimaryEntity() + }) +} + +// SetFileChildren sets the "file_children" field. +func (u *FileUpsertOne) SetFileChildren(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetFileChildren(v) + }) +} + +// UpdateFileChildren sets the "file_children" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateFileChildren() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateFileChildren() + }) +} + +// ClearFileChildren clears the value of the "file_children" field. +func (u *FileUpsertOne) ClearFileChildren() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.ClearFileChildren() + }) +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (u *FileUpsertOne) SetIsSymbolic(v bool) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetIsSymbolic(v) + }) +} + +// UpdateIsSymbolic sets the "is_symbolic" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateIsSymbolic() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateIsSymbolic() + }) +} + +// SetProps sets the "props" field. +func (u *FileUpsertOne) SetProps(v *types.FileProps) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetProps(v) + }) +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateProps() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateProps() + }) +} + +// ClearProps clears the value of the "props" field. +func (u *FileUpsertOne) ClearProps() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.ClearProps() + }) +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (u *FileUpsertOne) SetStoragePolicyFiles(v int) *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.SetStoragePolicyFiles(v) + }) +} + +// UpdateStoragePolicyFiles sets the "storage_policy_files" field to the value that was provided on create. +func (u *FileUpsertOne) UpdateStoragePolicyFiles() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.UpdateStoragePolicyFiles() + }) +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (u *FileUpsertOne) ClearStoragePolicyFiles() *FileUpsertOne { + return u.Update(func(s *FileUpsert) { + s.ClearStoragePolicyFiles() + }) +} + +// Exec executes the query. +func (u *FileUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for FileCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *FileUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *FileUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *FileUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *FileCreate) SetRawID(t int) *FileCreate { + m.mutation.SetRawID(t) + return m +} + +// FileCreateBulk is the builder for creating many File entities in bulk. +type FileCreateBulk struct { + config + err error + builders []*FileCreate + conflict []sql.ConflictOption +} + +// Save creates the File entities in the database. +func (fcb *FileCreateBulk) Save(ctx context.Context) ([]*File, error) { + if fcb.err != nil { + return nil, fcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(fcb.builders)) + nodes := make([]*File, len(fcb.builders)) + mutators := make([]Mutator, len(fcb.builders)) + for i := range fcb.builders { + func(i int, root context.Context) { + builder := fcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*FileMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, fcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = fcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, fcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, fcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (fcb *FileCreateBulk) SaveX(ctx context.Context) []*File { + v, err := fcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (fcb *FileCreateBulk) Exec(ctx context.Context) error { + _, err := fcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (fcb *FileCreateBulk) ExecX(ctx context.Context) { + if err := fcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.File.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.FileUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (fcb *FileCreateBulk) OnConflict(opts ...sql.ConflictOption) *FileUpsertBulk { + fcb.conflict = opts + return &FileUpsertBulk{ + create: fcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.File.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (fcb *FileCreateBulk) OnConflictColumns(columns ...string) *FileUpsertBulk { + fcb.conflict = append(fcb.conflict, sql.ConflictColumns(columns...)) + return &FileUpsertBulk{ + create: fcb, + } +} + +// FileUpsertBulk is the builder for "upsert"-ing +// a bulk of File nodes. +type FileUpsertBulk struct { + create *FileCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.File.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *FileUpsertBulk) UpdateNewValues() *FileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(file.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.File.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *FileUpsertBulk) Ignore() *FileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *FileUpsertBulk) DoNothing() *FileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the FileCreateBulk.OnConflict +// documentation for more info. +func (u *FileUpsertBulk) Update(set func(*FileUpsert)) *FileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&FileUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *FileUpsertBulk) SetUpdatedAt(v time.Time) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateUpdatedAt() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *FileUpsertBulk) SetDeletedAt(v time.Time) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateDeletedAt() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *FileUpsertBulk) ClearDeletedAt() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *FileUpsertBulk) SetType(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetType(v) + }) +} + +// AddType adds v to the "type" field. 
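+//
+// Editorial sketch, not part of the ent-generated output: the bulk variant follows
+// the same pattern; the conflict column and builders are illustrative:
+//
+//	err := client.File.CreateBulk(builders...).
+//		OnConflictColumns(file.FieldName).
+//		UpdateNewValues().
+//		Exec(ctx)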
+func (u *FileUpsertBulk) AddType(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.AddType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateType() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateType() + }) +} + +// SetName sets the "name" field. +func (u *FileUpsertBulk) SetName(v string) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateName() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateName() + }) +} + +// SetOwnerID sets the "owner_id" field. +func (u *FileUpsertBulk) SetOwnerID(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetOwnerID(v) + }) +} + +// UpdateOwnerID sets the "owner_id" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateOwnerID() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateOwnerID() + }) +} + +// SetSize sets the "size" field. +func (u *FileUpsertBulk) SetSize(v int64) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetSize(v) + }) +} + +// AddSize adds v to the "size" field. +func (u *FileUpsertBulk) AddSize(v int64) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.AddSize(v) + }) +} + +// UpdateSize sets the "size" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateSize() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateSize() + }) +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (u *FileUpsertBulk) SetPrimaryEntity(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetPrimaryEntity(v) + }) +} + +// AddPrimaryEntity adds v to the "primary_entity" field. +func (u *FileUpsertBulk) AddPrimaryEntity(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.AddPrimaryEntity(v) + }) +} + +// UpdatePrimaryEntity sets the "primary_entity" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdatePrimaryEntity() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdatePrimaryEntity() + }) +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (u *FileUpsertBulk) ClearPrimaryEntity() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.ClearPrimaryEntity() + }) +} + +// SetFileChildren sets the "file_children" field. +func (u *FileUpsertBulk) SetFileChildren(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetFileChildren(v) + }) +} + +// UpdateFileChildren sets the "file_children" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateFileChildren() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateFileChildren() + }) +} + +// ClearFileChildren clears the value of the "file_children" field. +func (u *FileUpsertBulk) ClearFileChildren() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.ClearFileChildren() + }) +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (u *FileUpsertBulk) SetIsSymbolic(v bool) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetIsSymbolic(v) + }) +} + +// UpdateIsSymbolic sets the "is_symbolic" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateIsSymbolic() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateIsSymbolic() + }) +} + +// SetProps sets the "props" field. 
+func (u *FileUpsertBulk) SetProps(v *types.FileProps) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetProps(v) + }) +} + +// UpdateProps sets the "props" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateProps() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateProps() + }) +} + +// ClearProps clears the value of the "props" field. +func (u *FileUpsertBulk) ClearProps() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.ClearProps() + }) +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (u *FileUpsertBulk) SetStoragePolicyFiles(v int) *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.SetStoragePolicyFiles(v) + }) +} + +// UpdateStoragePolicyFiles sets the "storage_policy_files" field to the value that was provided on create. +func (u *FileUpsertBulk) UpdateStoragePolicyFiles() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.UpdateStoragePolicyFiles() + }) +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (u *FileUpsertBulk) ClearStoragePolicyFiles() *FileUpsertBulk { + return u.Update(func(s *FileUpsert) { + s.ClearStoragePolicyFiles() + }) +} + +// Exec executes the query. +func (u *FileUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the FileCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for FileCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *FileUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/file_delete.go b/ent/file_delete.go new file mode 100644 index 00000000..f9b75ca3 --- /dev/null +++ b/ent/file_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// FileDelete is the builder for deleting a File entity. +type FileDelete struct { + config + hooks []Hook + mutation *FileMutation +} + +// Where appends a list predicates to the FileDelete builder. +func (fd *FileDelete) Where(ps ...predicate.File) *FileDelete { + fd.mutation.Where(ps...) + return fd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (fd *FileDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, fd.sqlExec, fd.mutation, fd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (fd *FileDelete) ExecX(ctx context.Context) int { + n, err := fd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (fd *FileDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(file.Table, sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt)) + if ps := fd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, fd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + fd.mutation.done = true + return affected, err +} + +// FileDeleteOne is the builder for deleting a single File entity. +type FileDeleteOne struct { + fd *FileDelete +} + +// Where appends a list predicates to the FileDelete builder. +func (fdo *FileDeleteOne) Where(ps ...predicate.File) *FileDeleteOne { + fdo.fd.mutation.Where(ps...) + return fdo +} + +// Exec executes the deletion query. +func (fdo *FileDeleteOne) Exec(ctx context.Context) error { + n, err := fdo.fd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{file.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (fdo *FileDeleteOne) ExecX(ctx context.Context) { + if err := fdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/file_query.go b/ent/file_query.go new file mode 100644 index 00000000..a978798f --- /dev/null +++ b/ent/file_query.go @@ -0,0 +1,1156 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// FileQuery is the builder for querying File entities. +type FileQuery struct { + config + ctx *QueryContext + order []file.OrderOption + inters []Interceptor + predicates []predicate.File + withOwner *UserQuery + withStoragePolicies *StoragePolicyQuery + withParent *FileQuery + withChildren *FileQuery + withMetadata *MetadataQuery + withEntities *EntityQuery + withShares *ShareQuery + withDirectLinks *DirectLinkQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the FileQuery builder. +func (fq *FileQuery) Where(ps ...predicate.File) *FileQuery { + fq.predicates = append(fq.predicates, ps...) + return fq +} + +// Limit the number of records to be returned by this query. +func (fq *FileQuery) Limit(limit int) *FileQuery { + fq.ctx.Limit = &limit + return fq +} + +// Offset to start from. +func (fq *FileQuery) Offset(offset int) *FileQuery { + fq.ctx.Offset = &offset + return fq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (fq *FileQuery) Unique(unique bool) *FileQuery { + fq.ctx.Unique = &unique + return fq +} + +// Order specifies how the records should be ordered. 
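+//
+// Editorial sketch, not part of the ent-generated output: ordering combined with a
+// predicate and a limit, assuming the standard ent.Desc helper; values are illustrative:
+//
+//	files, err := client.File.Query().
+//		Where(file.HasShares()).
+//		Order(ent.Desc(file.FieldUpdatedAt)).
+//		Limit(20).
+//		All(ctx)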
+func (fq *FileQuery) Order(o ...file.OrderOption) *FileQuery { + fq.order = append(fq.order, o...) + return fq +} + +// QueryOwner chains the current query on the "owner" edge. +func (fq *FileQuery) QueryOwner() *UserQuery { + query := (&UserClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.OwnerTable, file.OwnerColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryStoragePolicies chains the current query on the "storage_policies" edge. +func (fq *FileQuery) QueryStoragePolicies() *StoragePolicyQuery { + query := (&StoragePolicyClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.StoragePoliciesTable, file.StoragePoliciesColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryParent chains the current query on the "parent" edge. +func (fq *FileQuery) QueryParent() *FileQuery { + query := (&FileClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, file.ParentTable, file.ParentColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChildren chains the current query on the "children" edge. +func (fq *FileQuery) QueryChildren() *FileQuery { + query := (&FileClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.ChildrenTable, file.ChildrenColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryMetadata chains the current query on the "metadata" edge. 
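+//
+// Editorial sketch, not part of the ent-generated output: an edge traversal hands back
+// a query builder for the neighbor type, e.g. (predicate and result use are illustrative):
+//
+//	metas, err := client.File.Query().
+//		Where(file.HasShares()).
+//		QueryMetadata().
+//		All(ctx)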
+func (fq *FileQuery) QueryMetadata() *MetadataQuery { + query := (&MetadataClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(metadata.Table, metadata.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.MetadataTable, file.MetadataColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryEntities chains the current query on the "entities" edge. +func (fq *FileQuery) QueryEntities() *EntityQuery { + query := (&EntityClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, file.EntitiesTable, file.EntitiesPrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryShares chains the current query on the "shares" edge. +func (fq *FileQuery) QueryShares() *ShareQuery { + query := (&ShareClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(share.Table, share.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.SharesTable, file.SharesColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryDirectLinks chains the current query on the "direct_links" edge. +func (fq *FileQuery) QueryDirectLinks() *DirectLinkQuery { + query := (&DirectLinkClient{config: fq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := fq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(file.Table, file.FieldID, selector), + sqlgraph.To(directlink.Table, directlink.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, file.DirectLinksTable, file.DirectLinksColumn), + ) + fromU = sqlgraph.SetNeighbors(fq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first File entity from the query. +// Returns a *NotFoundError when no File was found. +func (fq *FileQuery) First(ctx context.Context) (*File, error) { + nodes, err := fq.Limit(1).All(setContextOp(ctx, fq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{file.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (fq *FileQuery) FirstX(ctx context.Context) *File { + node, err := fq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first File ID from the query. 
+// Returns a *NotFoundError when no File ID was found. +func (fq *FileQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = fq.Limit(1).IDs(setContextOp(ctx, fq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{file.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (fq *FileQuery) FirstIDX(ctx context.Context) int { + id, err := fq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single File entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one File entity is found. +// Returns a *NotFoundError when no File entities are found. +func (fq *FileQuery) Only(ctx context.Context) (*File, error) { + nodes, err := fq.Limit(2).All(setContextOp(ctx, fq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{file.Label} + default: + return nil, &NotSingularError{file.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (fq *FileQuery) OnlyX(ctx context.Context) *File { + node, err := fq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only File ID in the query. +// Returns a *NotSingularError when more than one File ID is found. +// Returns a *NotFoundError when no entities are found. +func (fq *FileQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = fq.Limit(2).IDs(setContextOp(ctx, fq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{file.Label} + default: + err = &NotSingularError{file.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (fq *FileQuery) OnlyIDX(ctx context.Context) int { + id, err := fq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Files. +func (fq *FileQuery) All(ctx context.Context) ([]*File, error) { + ctx = setContextOp(ctx, fq.ctx, "All") + if err := fq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*File, *FileQuery]() + return withInterceptors[[]*File](ctx, fq, qr, fq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (fq *FileQuery) AllX(ctx context.Context) []*File { + nodes, err := fq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of File IDs. +func (fq *FileQuery) IDs(ctx context.Context) (ids []int, err error) { + if fq.ctx.Unique == nil && fq.path != nil { + fq.Unique(true) + } + ctx = setContextOp(ctx, fq.ctx, "IDs") + if err = fq.Select(file.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (fq *FileQuery) IDsX(ctx context.Context) []int { + ids, err := fq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (fq *FileQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, fq.ctx, "Count") + if err := fq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, fq, querierCount[*FileQuery](), fq.inters) +} + +// CountX is like Count, but panics if an error occurs. 
+func (fq *FileQuery) CountX(ctx context.Context) int { + count, err := fq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (fq *FileQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, fq.ctx, "Exist") + switch _, err := fq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (fq *FileQuery) ExistX(ctx context.Context) bool { + exist, err := fq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the FileQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (fq *FileQuery) Clone() *FileQuery { + if fq == nil { + return nil + } + return &FileQuery{ + config: fq.config, + ctx: fq.ctx.Clone(), + order: append([]file.OrderOption{}, fq.order...), + inters: append([]Interceptor{}, fq.inters...), + predicates: append([]predicate.File{}, fq.predicates...), + withOwner: fq.withOwner.Clone(), + withStoragePolicies: fq.withStoragePolicies.Clone(), + withParent: fq.withParent.Clone(), + withChildren: fq.withChildren.Clone(), + withMetadata: fq.withMetadata.Clone(), + withEntities: fq.withEntities.Clone(), + withShares: fq.withShares.Clone(), + withDirectLinks: fq.withDirectLinks.Clone(), + // clone intermediate query. + sql: fq.sql.Clone(), + path: fq.path, + } +} + +// WithOwner tells the query-builder to eager-load the nodes that are connected to +// the "owner" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithOwner(opts ...func(*UserQuery)) *FileQuery { + query := (&UserClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withOwner = query + return fq +} + +// WithStoragePolicies tells the query-builder to eager-load the nodes that are connected to +// the "storage_policies" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithStoragePolicies(opts ...func(*StoragePolicyQuery)) *FileQuery { + query := (&StoragePolicyClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withStoragePolicies = query + return fq +} + +// WithParent tells the query-builder to eager-load the nodes that are connected to +// the "parent" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithParent(opts ...func(*FileQuery)) *FileQuery { + query := (&FileClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withParent = query + return fq +} + +// WithChildren tells the query-builder to eager-load the nodes that are connected to +// the "children" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithChildren(opts ...func(*FileQuery)) *FileQuery { + query := (&FileClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withChildren = query + return fq +} + +// WithMetadata tells the query-builder to eager-load the nodes that are connected to +// the "metadata" edge. The optional arguments are used to configure the query builder of the edge. 
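+//
+// Editorial sketch, not part of the ent-generated output: eager-loading several edges
+// in a single query (the predicate and chosen edges are illustrative):
+//
+//	f, err := client.File.Query().
+//		Where(file.HasShares()).
+//		WithMetadata().
+//		WithEntities().
+//		Only(ctx)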
+func (fq *FileQuery) WithMetadata(opts ...func(*MetadataQuery)) *FileQuery { + query := (&MetadataClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withMetadata = query + return fq +} + +// WithEntities tells the query-builder to eager-load the nodes that are connected to +// the "entities" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithEntities(opts ...func(*EntityQuery)) *FileQuery { + query := (&EntityClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withEntities = query + return fq +} + +// WithShares tells the query-builder to eager-load the nodes that are connected to +// the "shares" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithShares(opts ...func(*ShareQuery)) *FileQuery { + query := (&ShareClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withShares = query + return fq +} + +// WithDirectLinks tells the query-builder to eager-load the nodes that are connected to +// the "direct_links" edge. The optional arguments are used to configure the query builder of the edge. +func (fq *FileQuery) WithDirectLinks(opts ...func(*DirectLinkQuery)) *FileQuery { + query := (&DirectLinkClient{config: fq.config}).Query() + for _, opt := range opts { + opt(query) + } + fq.withDirectLinks = query + return fq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.File.Query(). +// GroupBy(file.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (fq *FileQuery) GroupBy(field string, fields ...string) *FileGroupBy { + fq.ctx.Fields = append([]string{field}, fields...) + grbuild := &FileGroupBy{build: fq} + grbuild.flds = &fq.ctx.Fields + grbuild.label = file.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.File.Query(). +// Select(file.FieldCreatedAt). +// Scan(ctx, &v) +func (fq *FileQuery) Select(fields ...string) *FileSelect { + fq.ctx.Fields = append(fq.ctx.Fields, fields...) + sbuild := &FileSelect{FileQuery: fq} + sbuild.label = file.Label + sbuild.flds, sbuild.scan = &fq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a FileSelect configured with the given aggregations. +func (fq *FileQuery) Aggregate(fns ...AggregateFunc) *FileSelect { + return fq.Select().Aggregate(fns...) 
+} + +func (fq *FileQuery) prepareQuery(ctx context.Context) error { + for _, inter := range fq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, fq); err != nil { + return err + } + } + } + for _, f := range fq.ctx.Fields { + if !file.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if fq.path != nil { + prev, err := fq.path(ctx) + if err != nil { + return err + } + fq.sql = prev + } + return nil +} + +func (fq *FileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*File, error) { + var ( + nodes = []*File{} + _spec = fq.querySpec() + loadedTypes = [8]bool{ + fq.withOwner != nil, + fq.withStoragePolicies != nil, + fq.withParent != nil, + fq.withChildren != nil, + fq.withMetadata != nil, + fq.withEntities != nil, + fq.withShares != nil, + fq.withDirectLinks != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*File).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &File{config: fq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, fq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := fq.withOwner; query != nil { + if err := fq.loadOwner(ctx, query, nodes, nil, + func(n *File, e *User) { n.Edges.Owner = e }); err != nil { + return nil, err + } + } + if query := fq.withStoragePolicies; query != nil { + if err := fq.loadStoragePolicies(ctx, query, nodes, nil, + func(n *File, e *StoragePolicy) { n.Edges.StoragePolicies = e }); err != nil { + return nil, err + } + } + if query := fq.withParent; query != nil { + if err := fq.loadParent(ctx, query, nodes, nil, + func(n *File, e *File) { n.Edges.Parent = e }); err != nil { + return nil, err + } + } + if query := fq.withChildren; query != nil { + if err := fq.loadChildren(ctx, query, nodes, + func(n *File) { n.Edges.Children = []*File{} }, + func(n *File, e *File) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err + } + } + if query := fq.withMetadata; query != nil { + if err := fq.loadMetadata(ctx, query, nodes, + func(n *File) { n.Edges.Metadata = []*Metadata{} }, + func(n *File, e *Metadata) { n.Edges.Metadata = append(n.Edges.Metadata, e) }); err != nil { + return nil, err + } + } + if query := fq.withEntities; query != nil { + if err := fq.loadEntities(ctx, query, nodes, + func(n *File) { n.Edges.Entities = []*Entity{} }, + func(n *File, e *Entity) { n.Edges.Entities = append(n.Edges.Entities, e) }); err != nil { + return nil, err + } + } + if query := fq.withShares; query != nil { + if err := fq.loadShares(ctx, query, nodes, + func(n *File) { n.Edges.Shares = []*Share{} }, + func(n *File, e *Share) { n.Edges.Shares = append(n.Edges.Shares, e) }); err != nil { + return nil, err + } + } + if query := fq.withDirectLinks; query != nil { + if err := fq.loadDirectLinks(ctx, query, nodes, + func(n *File) { n.Edges.DirectLinks = []*DirectLink{} }, + func(n *File, e *DirectLink) { n.Edges.DirectLinks = append(n.Edges.DirectLinks, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (fq *FileQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*File, init func(*File), 
assign func(*File, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + fk := nodes[i].OwnerID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (fq *FileQuery) loadStoragePolicies(ctx context.Context, query *StoragePolicyQuery, nodes []*File, init func(*File), assign func(*File, *StoragePolicy)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + fk := nodes[i].StoragePolicyFiles + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(storagepolicy.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "storage_policy_files" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (fq *FileQuery) loadParent(ctx context.Context, query *FileQuery, nodes []*File, init func(*File), assign func(*File, *File)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + fk := nodes[i].FileChildren + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(file.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_children" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (fq *FileQuery) loadChildren(ctx context.Context, query *FileQuery, nodes []*File, init func(*File), assign func(*File, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(file.FieldFileChildren) + } + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(file.ChildrenColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.FileChildren + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "file_children" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (fq *FileQuery) loadMetadata(ctx context.Context, query *MetadataQuery, nodes []*File, init func(*File), assign func(*File, *Metadata)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(metadata.FieldFileID) + } + query.Where(predicate.Metadata(func(s *sql.Selector) { + 
s.Where(sql.InValues(s.C(file.MetadataColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.FileID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "file_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (fq *FileQuery) loadEntities(ctx context.Context, query *EntityQuery, nodes []*File, init func(*File), assign func(*File, *Entity)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*File) + nids := make(map[int]map[*File]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(file.EntitiesTable) + s.Join(joinT).On(s.C(entity.FieldID), joinT.C(file.EntitiesPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(file.EntitiesPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(file.EntitiesPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*File]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Entity](ctx, query, qr, query.inters) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "entities" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (fq *FileQuery) loadShares(ctx context.Context, query *ShareQuery, nodes []*File, init func(*File), assign func(*File, *Share)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Share(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(file.SharesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.file_shares + if fk == nil { + return fmt.Errorf(`foreign-key "file_shares" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "file_shares" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (fq *FileQuery) loadDirectLinks(ctx context.Context, query *DirectLinkQuery, nodes []*File, init func(*File), assign func(*File, *DirectLink)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if 
len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(directlink.FieldFileID) + } + query.Where(predicate.DirectLink(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(file.DirectLinksColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.FileID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "file_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (fq *FileQuery) sqlCount(ctx context.Context) (int, error) { + _spec := fq.querySpec() + _spec.Node.Columns = fq.ctx.Fields + if len(fq.ctx.Fields) > 0 { + _spec.Unique = fq.ctx.Unique != nil && *fq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, fq.driver, _spec) +} + +func (fq *FileQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(file.Table, file.Columns, sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt)) + _spec.From = fq.sql + if unique := fq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if fq.path != nil { + _spec.Unique = true + } + if fields := fq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, file.FieldID) + for i := range fields { + if fields[i] != file.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if fq.withOwner != nil { + _spec.Node.AddColumnOnce(file.FieldOwnerID) + } + if fq.withStoragePolicies != nil { + _spec.Node.AddColumnOnce(file.FieldStoragePolicyFiles) + } + if fq.withParent != nil { + _spec.Node.AddColumnOnce(file.FieldFileChildren) + } + } + if ps := fq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := fq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := fq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := fq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (fq *FileQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(fq.driver.Dialect()) + t1 := builder.Table(file.Table) + columns := fq.ctx.Fields + if len(columns) == 0 { + columns = file.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if fq.sql != nil { + selector = fq.sql + selector.Select(selector.Columns(columns...)...) + } + if fq.ctx.Unique != nil && *fq.ctx.Unique { + selector.Distinct() + } + for _, p := range fq.predicates { + p(selector) + } + for _, p := range fq.order { + p(selector) + } + if offset := fq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := fq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// FileGroupBy is the group-by builder for File entities. +type FileGroupBy struct { + selector + build *FileQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (fgb *FileGroupBy) Aggregate(fns ...AggregateFunc) *FileGroupBy { + fgb.fns = append(fgb.fns, fns...) + return fgb +} + +// Scan applies the selector query and scans the result into the given value. 
+func (fgb *FileGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, fgb.build.ctx, "GroupBy") + if err := fgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*FileQuery, *FileGroupBy](ctx, fgb.build, fgb, fgb.build.inters, v) +} + +func (fgb *FileGroupBy) sqlScan(ctx context.Context, root *FileQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(fgb.fns)) + for _, fn := range fgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*fgb.flds)+len(fgb.fns)) + for _, f := range *fgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*fgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := fgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// FileSelect is the builder for selecting fields of File entities. +type FileSelect struct { + *FileQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (fs *FileSelect) Aggregate(fns ...AggregateFunc) *FileSelect { + fs.fns = append(fs.fns, fns...) + return fs +} + +// Scan applies the selector query and scans the result into the given value. +func (fs *FileSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, fs.ctx, "Select") + if err := fs.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*FileQuery, *FileSelect](ctx, fs.FileQuery, fs, fs.inters, v) +} + +func (fs *FileSelect) sqlScan(ctx context.Context, root *FileQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(fs.fns)) + for _, fn := range fs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*fs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := fs.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/file_update.go b/ent/file_update.go new file mode 100644 index 00000000..4f913962 --- /dev/null +++ b/ent/file_update.go @@ -0,0 +1,1803 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// FileUpdate is the builder for updating File entities. +type FileUpdate struct { + config + hooks []Hook + mutation *FileMutation +} + +// Where appends a list predicates to the FileUpdate builder. +func (fu *FileUpdate) Where(ps ...predicate.File) *FileUpdate { + fu.mutation.Where(ps...) 
+ return fu +} + +// SetUpdatedAt sets the "updated_at" field. +func (fu *FileUpdate) SetUpdatedAt(t time.Time) *FileUpdate { + fu.mutation.SetUpdatedAt(t) + return fu +} + +// SetDeletedAt sets the "deleted_at" field. +func (fu *FileUpdate) SetDeletedAt(t time.Time) *FileUpdate { + fu.mutation.SetDeletedAt(t) + return fu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (fu *FileUpdate) SetNillableDeletedAt(t *time.Time) *FileUpdate { + if t != nil { + fu.SetDeletedAt(*t) + } + return fu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (fu *FileUpdate) ClearDeletedAt() *FileUpdate { + fu.mutation.ClearDeletedAt() + return fu +} + +// SetType sets the "type" field. +func (fu *FileUpdate) SetType(i int) *FileUpdate { + fu.mutation.ResetType() + fu.mutation.SetType(i) + return fu +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (fu *FileUpdate) SetNillableType(i *int) *FileUpdate { + if i != nil { + fu.SetType(*i) + } + return fu +} + +// AddType adds i to the "type" field. +func (fu *FileUpdate) AddType(i int) *FileUpdate { + fu.mutation.AddType(i) + return fu +} + +// SetName sets the "name" field. +func (fu *FileUpdate) SetName(s string) *FileUpdate { + fu.mutation.SetName(s) + return fu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (fu *FileUpdate) SetNillableName(s *string) *FileUpdate { + if s != nil { + fu.SetName(*s) + } + return fu +} + +// SetOwnerID sets the "owner_id" field. +func (fu *FileUpdate) SetOwnerID(i int) *FileUpdate { + fu.mutation.SetOwnerID(i) + return fu +} + +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (fu *FileUpdate) SetNillableOwnerID(i *int) *FileUpdate { + if i != nil { + fu.SetOwnerID(*i) + } + return fu +} + +// SetSize sets the "size" field. +func (fu *FileUpdate) SetSize(i int64) *FileUpdate { + fu.mutation.ResetSize() + fu.mutation.SetSize(i) + return fu +} + +// SetNillableSize sets the "size" field if the given value is not nil. +func (fu *FileUpdate) SetNillableSize(i *int64) *FileUpdate { + if i != nil { + fu.SetSize(*i) + } + return fu +} + +// AddSize adds i to the "size" field. +func (fu *FileUpdate) AddSize(i int64) *FileUpdate { + fu.mutation.AddSize(i) + return fu +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (fu *FileUpdate) SetPrimaryEntity(i int) *FileUpdate { + fu.mutation.ResetPrimaryEntity() + fu.mutation.SetPrimaryEntity(i) + return fu +} + +// SetNillablePrimaryEntity sets the "primary_entity" field if the given value is not nil. +func (fu *FileUpdate) SetNillablePrimaryEntity(i *int) *FileUpdate { + if i != nil { + fu.SetPrimaryEntity(*i) + } + return fu +} + +// AddPrimaryEntity adds i to the "primary_entity" field. +func (fu *FileUpdate) AddPrimaryEntity(i int) *FileUpdate { + fu.mutation.AddPrimaryEntity(i) + return fu +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (fu *FileUpdate) ClearPrimaryEntity() *FileUpdate { + fu.mutation.ClearPrimaryEntity() + return fu +} + +// SetFileChildren sets the "file_children" field. +func (fu *FileUpdate) SetFileChildren(i int) *FileUpdate { + fu.mutation.SetFileChildren(i) + return fu +} + +// SetNillableFileChildren sets the "file_children" field if the given value is not nil. 
+func (fu *FileUpdate) SetNillableFileChildren(i *int) *FileUpdate { + if i != nil { + fu.SetFileChildren(*i) + } + return fu +} + +// ClearFileChildren clears the value of the "file_children" field. +func (fu *FileUpdate) ClearFileChildren() *FileUpdate { + fu.mutation.ClearFileChildren() + return fu +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (fu *FileUpdate) SetIsSymbolic(b bool) *FileUpdate { + fu.mutation.SetIsSymbolic(b) + return fu +} + +// SetNillableIsSymbolic sets the "is_symbolic" field if the given value is not nil. +func (fu *FileUpdate) SetNillableIsSymbolic(b *bool) *FileUpdate { + if b != nil { + fu.SetIsSymbolic(*b) + } + return fu +} + +// SetProps sets the "props" field. +func (fu *FileUpdate) SetProps(tp *types.FileProps) *FileUpdate { + fu.mutation.SetProps(tp) + return fu +} + +// ClearProps clears the value of the "props" field. +func (fu *FileUpdate) ClearProps() *FileUpdate { + fu.mutation.ClearProps() + return fu +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (fu *FileUpdate) SetStoragePolicyFiles(i int) *FileUpdate { + fu.mutation.SetStoragePolicyFiles(i) + return fu +} + +// SetNillableStoragePolicyFiles sets the "storage_policy_files" field if the given value is not nil. +func (fu *FileUpdate) SetNillableStoragePolicyFiles(i *int) *FileUpdate { + if i != nil { + fu.SetStoragePolicyFiles(*i) + } + return fu +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (fu *FileUpdate) ClearStoragePolicyFiles() *FileUpdate { + fu.mutation.ClearStoragePolicyFiles() + return fu +} + +// SetOwner sets the "owner" edge to the User entity. +func (fu *FileUpdate) SetOwner(u *User) *FileUpdate { + return fu.SetOwnerID(u.ID) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (fu *FileUpdate) SetStoragePoliciesID(id int) *FileUpdate { + fu.mutation.SetStoragePoliciesID(id) + return fu +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. +func (fu *FileUpdate) SetNillableStoragePoliciesID(id *int) *FileUpdate { + if id != nil { + fu = fu.SetStoragePoliciesID(*id) + } + return fu +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. +func (fu *FileUpdate) SetStoragePolicies(s *StoragePolicy) *FileUpdate { + return fu.SetStoragePoliciesID(s.ID) +} + +// SetParentID sets the "parent" edge to the File entity by ID. +func (fu *FileUpdate) SetParentID(id int) *FileUpdate { + fu.mutation.SetParentID(id) + return fu +} + +// SetNillableParentID sets the "parent" edge to the File entity by ID if the given value is not nil. +func (fu *FileUpdate) SetNillableParentID(id *int) *FileUpdate { + if id != nil { + fu = fu.SetParentID(*id) + } + return fu +} + +// SetParent sets the "parent" edge to the File entity. +func (fu *FileUpdate) SetParent(f *File) *FileUpdate { + return fu.SetParentID(f.ID) +} + +// AddChildIDs adds the "children" edge to the File entity by IDs. +func (fu *FileUpdate) AddChildIDs(ids ...int) *FileUpdate { + fu.mutation.AddChildIDs(ids...) + return fu +} + +// AddChildren adds the "children" edges to the File entity. +func (fu *FileUpdate) AddChildren(f ...*File) *FileUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return fu.AddChildIDs(ids...) +} + +// AddMetadatumIDs adds the "metadata" edge to the Metadata entity by IDs. 
+func (fu *FileUpdate) AddMetadatumIDs(ids ...int) *FileUpdate { + fu.mutation.AddMetadatumIDs(ids...) + return fu +} + +// AddMetadata adds the "metadata" edges to the Metadata entity. +func (fu *FileUpdate) AddMetadata(m ...*Metadata) *FileUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return fu.AddMetadatumIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (fu *FileUpdate) AddEntityIDs(ids ...int) *FileUpdate { + fu.mutation.AddEntityIDs(ids...) + return fu +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (fu *FileUpdate) AddEntities(e ...*Entity) *FileUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return fu.AddEntityIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (fu *FileUpdate) AddShareIDs(ids ...int) *FileUpdate { + fu.mutation.AddShareIDs(ids...) + return fu +} + +// AddShares adds the "shares" edges to the Share entity. +func (fu *FileUpdate) AddShares(s ...*Share) *FileUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return fu.AddShareIDs(ids...) +} + +// AddDirectLinkIDs adds the "direct_links" edge to the DirectLink entity by IDs. +func (fu *FileUpdate) AddDirectLinkIDs(ids ...int) *FileUpdate { + fu.mutation.AddDirectLinkIDs(ids...) + return fu +} + +// AddDirectLinks adds the "direct_links" edges to the DirectLink entity. +func (fu *FileUpdate) AddDirectLinks(d ...*DirectLink) *FileUpdate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return fu.AddDirectLinkIDs(ids...) +} + +// Mutation returns the FileMutation object of the builder. +func (fu *FileUpdate) Mutation() *FileMutation { + return fu.mutation +} + +// ClearOwner clears the "owner" edge to the User entity. +func (fu *FileUpdate) ClearOwner() *FileUpdate { + fu.mutation.ClearOwner() + return fu +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. +func (fu *FileUpdate) ClearStoragePolicies() *FileUpdate { + fu.mutation.ClearStoragePolicies() + return fu +} + +// ClearParent clears the "parent" edge to the File entity. +func (fu *FileUpdate) ClearParent() *FileUpdate { + fu.mutation.ClearParent() + return fu +} + +// ClearChildren clears all "children" edges to the File entity. +func (fu *FileUpdate) ClearChildren() *FileUpdate { + fu.mutation.ClearChildren() + return fu +} + +// RemoveChildIDs removes the "children" edge to File entities by IDs. +func (fu *FileUpdate) RemoveChildIDs(ids ...int) *FileUpdate { + fu.mutation.RemoveChildIDs(ids...) + return fu +} + +// RemoveChildren removes "children" edges to File entities. +func (fu *FileUpdate) RemoveChildren(f ...*File) *FileUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return fu.RemoveChildIDs(ids...) +} + +// ClearMetadata clears all "metadata" edges to the Metadata entity. +func (fu *FileUpdate) ClearMetadata() *FileUpdate { + fu.mutation.ClearMetadata() + return fu +} + +// RemoveMetadatumIDs removes the "metadata" edge to Metadata entities by IDs. +func (fu *FileUpdate) RemoveMetadatumIDs(ids ...int) *FileUpdate { + fu.mutation.RemoveMetadatumIDs(ids...) + return fu +} + +// RemoveMetadata removes "metadata" edges to Metadata entities. +func (fu *FileUpdate) RemoveMetadata(m ...*Metadata) *FileUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return fu.RemoveMetadatumIDs(ids...) 
+} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (fu *FileUpdate) ClearEntities() *FileUpdate { + fu.mutation.ClearEntities() + return fu +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (fu *FileUpdate) RemoveEntityIDs(ids ...int) *FileUpdate { + fu.mutation.RemoveEntityIDs(ids...) + return fu +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (fu *FileUpdate) RemoveEntities(e ...*Entity) *FileUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return fu.RemoveEntityIDs(ids...) +} + +// ClearShares clears all "shares" edges to the Share entity. +func (fu *FileUpdate) ClearShares() *FileUpdate { + fu.mutation.ClearShares() + return fu +} + +// RemoveShareIDs removes the "shares" edge to Share entities by IDs. +func (fu *FileUpdate) RemoveShareIDs(ids ...int) *FileUpdate { + fu.mutation.RemoveShareIDs(ids...) + return fu +} + +// RemoveShares removes "shares" edges to Share entities. +func (fu *FileUpdate) RemoveShares(s ...*Share) *FileUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return fu.RemoveShareIDs(ids...) +} + +// ClearDirectLinks clears all "direct_links" edges to the DirectLink entity. +func (fu *FileUpdate) ClearDirectLinks() *FileUpdate { + fu.mutation.ClearDirectLinks() + return fu +} + +// RemoveDirectLinkIDs removes the "direct_links" edge to DirectLink entities by IDs. +func (fu *FileUpdate) RemoveDirectLinkIDs(ids ...int) *FileUpdate { + fu.mutation.RemoveDirectLinkIDs(ids...) + return fu +} + +// RemoveDirectLinks removes "direct_links" edges to DirectLink entities. +func (fu *FileUpdate) RemoveDirectLinks(d ...*DirectLink) *FileUpdate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return fu.RemoveDirectLinkIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (fu *FileUpdate) Save(ctx context.Context) (int, error) { + if err := fu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, fu.sqlSave, fu.mutation, fu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (fu *FileUpdate) SaveX(ctx context.Context) int { + affected, err := fu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (fu *FileUpdate) Exec(ctx context.Context) error { + _, err := fu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (fu *FileUpdate) ExecX(ctx context.Context) { + if err := fu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (fu *FileUpdate) defaults() error { + if _, ok := fu.mutation.UpdatedAt(); !ok { + if file.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized file.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := file.UpdateDefaultUpdatedAt() + fu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (fu *FileUpdate) check() error { + if _, ok := fu.mutation.OwnerID(); fu.mutation.OwnerCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "File.owner"`) + } + return nil +} + +func (fu *FileUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := fu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(file.Table, file.Columns, sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt)) + if ps := fu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := fu.mutation.UpdatedAt(); ok { + _spec.SetField(file.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := fu.mutation.DeletedAt(); ok { + _spec.SetField(file.FieldDeletedAt, field.TypeTime, value) + } + if fu.mutation.DeletedAtCleared() { + _spec.ClearField(file.FieldDeletedAt, field.TypeTime) + } + if value, ok := fu.mutation.GetType(); ok { + _spec.SetField(file.FieldType, field.TypeInt, value) + } + if value, ok := fu.mutation.AddedType(); ok { + _spec.AddField(file.FieldType, field.TypeInt, value) + } + if value, ok := fu.mutation.Name(); ok { + _spec.SetField(file.FieldName, field.TypeString, value) + } + if value, ok := fu.mutation.Size(); ok { + _spec.SetField(file.FieldSize, field.TypeInt64, value) + } + if value, ok := fu.mutation.AddedSize(); ok { + _spec.AddField(file.FieldSize, field.TypeInt64, value) + } + if value, ok := fu.mutation.PrimaryEntity(); ok { + _spec.SetField(file.FieldPrimaryEntity, field.TypeInt, value) + } + if value, ok := fu.mutation.AddedPrimaryEntity(); ok { + _spec.AddField(file.FieldPrimaryEntity, field.TypeInt, value) + } + if fu.mutation.PrimaryEntityCleared() { + _spec.ClearField(file.FieldPrimaryEntity, field.TypeInt) + } + if value, ok := fu.mutation.IsSymbolic(); ok { + _spec.SetField(file.FieldIsSymbolic, field.TypeBool, value) + } + if value, ok := fu.mutation.Props(); ok { + _spec.SetField(file.FieldProps, field.TypeJSON, value) + } + if fu.mutation.PropsCleared() { + _spec.ClearField(file.FieldProps, field.TypeJSON) + } + if fu.mutation.OwnerCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.OwnerTable, + Columns: []string{file.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.OwnerTable, + Columns: []string{file.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.StoragePoliciesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.StoragePoliciesTable, + Columns: []string{file.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.StoragePoliciesTable, + Columns: []string{file.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + 
IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.ParentCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.ParentTable, + Columns: []string{file.ParentColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.ParentIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.ParentTable, + Columns: []string{file.ParentColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.ChildrenCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.RemovedChildrenIDs(); len(nodes) > 0 && !fu.mutation.ChildrenCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.ChildrenIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.MetadataCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.RemovedMetadataIDs(); len(nodes) > 0 && !fu.mutation.MetadataCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.MetadataIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = 
append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.RemovedEntitiesIDs(); len(nodes) > 0 && !fu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.RemovedSharesIDs(); len(nodes) > 0 && !fu.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fu.mutation.DirectLinksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fu.mutation.RemovedDirectLinksIDs(); len(nodes) > 0 && !fu.mutation.DirectLinksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if 
nodes := fu.mutation.DirectLinksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, fu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{file.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + fu.mutation.done = true + return n, nil +} + +// FileUpdateOne is the builder for updating a single File entity. +type FileUpdateOne struct { + config + fields []string + hooks []Hook + mutation *FileMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (fuo *FileUpdateOne) SetUpdatedAt(t time.Time) *FileUpdateOne { + fuo.mutation.SetUpdatedAt(t) + return fuo +} + +// SetDeletedAt sets the "deleted_at" field. +func (fuo *FileUpdateOne) SetDeletedAt(t time.Time) *FileUpdateOne { + fuo.mutation.SetDeletedAt(t) + return fuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableDeletedAt(t *time.Time) *FileUpdateOne { + if t != nil { + fuo.SetDeletedAt(*t) + } + return fuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (fuo *FileUpdateOne) ClearDeletedAt() *FileUpdateOne { + fuo.mutation.ClearDeletedAt() + return fuo +} + +// SetType sets the "type" field. +func (fuo *FileUpdateOne) SetType(i int) *FileUpdateOne { + fuo.mutation.ResetType() + fuo.mutation.SetType(i) + return fuo +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableType(i *int) *FileUpdateOne { + if i != nil { + fuo.SetType(*i) + } + return fuo +} + +// AddType adds i to the "type" field. +func (fuo *FileUpdateOne) AddType(i int) *FileUpdateOne { + fuo.mutation.AddType(i) + return fuo +} + +// SetName sets the "name" field. +func (fuo *FileUpdateOne) SetName(s string) *FileUpdateOne { + fuo.mutation.SetName(s) + return fuo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableName(s *string) *FileUpdateOne { + if s != nil { + fuo.SetName(*s) + } + return fuo +} + +// SetOwnerID sets the "owner_id" field. +func (fuo *FileUpdateOne) SetOwnerID(i int) *FileUpdateOne { + fuo.mutation.SetOwnerID(i) + return fuo +} + +// SetNillableOwnerID sets the "owner_id" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableOwnerID(i *int) *FileUpdateOne { + if i != nil { + fuo.SetOwnerID(*i) + } + return fuo +} + +// SetSize sets the "size" field. +func (fuo *FileUpdateOne) SetSize(i int64) *FileUpdateOne { + fuo.mutation.ResetSize() + fuo.mutation.SetSize(i) + return fuo +} + +// SetNillableSize sets the "size" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableSize(i *int64) *FileUpdateOne { + if i != nil { + fuo.SetSize(*i) + } + return fuo +} + +// AddSize adds i to the "size" field. +func (fuo *FileUpdateOne) AddSize(i int64) *FileUpdateOne { + fuo.mutation.AddSize(i) + return fuo +} + +// SetPrimaryEntity sets the "primary_entity" field. 
+func (fuo *FileUpdateOne) SetPrimaryEntity(i int) *FileUpdateOne { + fuo.mutation.ResetPrimaryEntity() + fuo.mutation.SetPrimaryEntity(i) + return fuo +} + +// SetNillablePrimaryEntity sets the "primary_entity" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillablePrimaryEntity(i *int) *FileUpdateOne { + if i != nil { + fuo.SetPrimaryEntity(*i) + } + return fuo +} + +// AddPrimaryEntity adds i to the "primary_entity" field. +func (fuo *FileUpdateOne) AddPrimaryEntity(i int) *FileUpdateOne { + fuo.mutation.AddPrimaryEntity(i) + return fuo +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (fuo *FileUpdateOne) ClearPrimaryEntity() *FileUpdateOne { + fuo.mutation.ClearPrimaryEntity() + return fuo +} + +// SetFileChildren sets the "file_children" field. +func (fuo *FileUpdateOne) SetFileChildren(i int) *FileUpdateOne { + fuo.mutation.SetFileChildren(i) + return fuo +} + +// SetNillableFileChildren sets the "file_children" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableFileChildren(i *int) *FileUpdateOne { + if i != nil { + fuo.SetFileChildren(*i) + } + return fuo +} + +// ClearFileChildren clears the value of the "file_children" field. +func (fuo *FileUpdateOne) ClearFileChildren() *FileUpdateOne { + fuo.mutation.ClearFileChildren() + return fuo +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (fuo *FileUpdateOne) SetIsSymbolic(b bool) *FileUpdateOne { + fuo.mutation.SetIsSymbolic(b) + return fuo +} + +// SetNillableIsSymbolic sets the "is_symbolic" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableIsSymbolic(b *bool) *FileUpdateOne { + if b != nil { + fuo.SetIsSymbolic(*b) + } + return fuo +} + +// SetProps sets the "props" field. +func (fuo *FileUpdateOne) SetProps(tp *types.FileProps) *FileUpdateOne { + fuo.mutation.SetProps(tp) + return fuo +} + +// ClearProps clears the value of the "props" field. +func (fuo *FileUpdateOne) ClearProps() *FileUpdateOne { + fuo.mutation.ClearProps() + return fuo +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (fuo *FileUpdateOne) SetStoragePolicyFiles(i int) *FileUpdateOne { + fuo.mutation.SetStoragePolicyFiles(i) + return fuo +} + +// SetNillableStoragePolicyFiles sets the "storage_policy_files" field if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableStoragePolicyFiles(i *int) *FileUpdateOne { + if i != nil { + fuo.SetStoragePolicyFiles(*i) + } + return fuo +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (fuo *FileUpdateOne) ClearStoragePolicyFiles() *FileUpdateOne { + fuo.mutation.ClearStoragePolicyFiles() + return fuo +} + +// SetOwner sets the "owner" edge to the User entity. +func (fuo *FileUpdateOne) SetOwner(u *User) *FileUpdateOne { + return fuo.SetOwnerID(u.ID) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (fuo *FileUpdateOne) SetStoragePoliciesID(id int) *FileUpdateOne { + fuo.mutation.SetStoragePoliciesID(id) + return fuo +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableStoragePoliciesID(id *int) *FileUpdateOne { + if id != nil { + fuo = fuo.SetStoragePoliciesID(*id) + } + return fuo +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. 
+func (fuo *FileUpdateOne) SetStoragePolicies(s *StoragePolicy) *FileUpdateOne { + return fuo.SetStoragePoliciesID(s.ID) +} + +// SetParentID sets the "parent" edge to the File entity by ID. +func (fuo *FileUpdateOne) SetParentID(id int) *FileUpdateOne { + fuo.mutation.SetParentID(id) + return fuo +} + +// SetNillableParentID sets the "parent" edge to the File entity by ID if the given value is not nil. +func (fuo *FileUpdateOne) SetNillableParentID(id *int) *FileUpdateOne { + if id != nil { + fuo = fuo.SetParentID(*id) + } + return fuo +} + +// SetParent sets the "parent" edge to the File entity. +func (fuo *FileUpdateOne) SetParent(f *File) *FileUpdateOne { + return fuo.SetParentID(f.ID) +} + +// AddChildIDs adds the "children" edge to the File entity by IDs. +func (fuo *FileUpdateOne) AddChildIDs(ids ...int) *FileUpdateOne { + fuo.mutation.AddChildIDs(ids...) + return fuo +} + +// AddChildren adds the "children" edges to the File entity. +func (fuo *FileUpdateOne) AddChildren(f ...*File) *FileUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return fuo.AddChildIDs(ids...) +} + +// AddMetadatumIDs adds the "metadata" edge to the Metadata entity by IDs. +func (fuo *FileUpdateOne) AddMetadatumIDs(ids ...int) *FileUpdateOne { + fuo.mutation.AddMetadatumIDs(ids...) + return fuo +} + +// AddMetadata adds the "metadata" edges to the Metadata entity. +func (fuo *FileUpdateOne) AddMetadata(m ...*Metadata) *FileUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return fuo.AddMetadatumIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (fuo *FileUpdateOne) AddEntityIDs(ids ...int) *FileUpdateOne { + fuo.mutation.AddEntityIDs(ids...) + return fuo +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (fuo *FileUpdateOne) AddEntities(e ...*Entity) *FileUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return fuo.AddEntityIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (fuo *FileUpdateOne) AddShareIDs(ids ...int) *FileUpdateOne { + fuo.mutation.AddShareIDs(ids...) + return fuo +} + +// AddShares adds the "shares" edges to the Share entity. +func (fuo *FileUpdateOne) AddShares(s ...*Share) *FileUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return fuo.AddShareIDs(ids...) +} + +// AddDirectLinkIDs adds the "direct_links" edge to the DirectLink entity by IDs. +func (fuo *FileUpdateOne) AddDirectLinkIDs(ids ...int) *FileUpdateOne { + fuo.mutation.AddDirectLinkIDs(ids...) + return fuo +} + +// AddDirectLinks adds the "direct_links" edges to the DirectLink entity. +func (fuo *FileUpdateOne) AddDirectLinks(d ...*DirectLink) *FileUpdateOne { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return fuo.AddDirectLinkIDs(ids...) +} + +// Mutation returns the FileMutation object of the builder. +func (fuo *FileUpdateOne) Mutation() *FileMutation { + return fuo.mutation +} + +// ClearOwner clears the "owner" edge to the User entity. +func (fuo *FileUpdateOne) ClearOwner() *FileUpdateOne { + fuo.mutation.ClearOwner() + return fuo +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. +func (fuo *FileUpdateOne) ClearStoragePolicies() *FileUpdateOne { + fuo.mutation.ClearStoragePolicies() + return fuo +} + +// ClearParent clears the "parent" edge to the File entity. 
+func (fuo *FileUpdateOne) ClearParent() *FileUpdateOne { + fuo.mutation.ClearParent() + return fuo +} + +// ClearChildren clears all "children" edges to the File entity. +func (fuo *FileUpdateOne) ClearChildren() *FileUpdateOne { + fuo.mutation.ClearChildren() + return fuo +} + +// RemoveChildIDs removes the "children" edge to File entities by IDs. +func (fuo *FileUpdateOne) RemoveChildIDs(ids ...int) *FileUpdateOne { + fuo.mutation.RemoveChildIDs(ids...) + return fuo +} + +// RemoveChildren removes "children" edges to File entities. +func (fuo *FileUpdateOne) RemoveChildren(f ...*File) *FileUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return fuo.RemoveChildIDs(ids...) +} + +// ClearMetadata clears all "metadata" edges to the Metadata entity. +func (fuo *FileUpdateOne) ClearMetadata() *FileUpdateOne { + fuo.mutation.ClearMetadata() + return fuo +} + +// RemoveMetadatumIDs removes the "metadata" edge to Metadata entities by IDs. +func (fuo *FileUpdateOne) RemoveMetadatumIDs(ids ...int) *FileUpdateOne { + fuo.mutation.RemoveMetadatumIDs(ids...) + return fuo +} + +// RemoveMetadata removes "metadata" edges to Metadata entities. +func (fuo *FileUpdateOne) RemoveMetadata(m ...*Metadata) *FileUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return fuo.RemoveMetadatumIDs(ids...) +} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (fuo *FileUpdateOne) ClearEntities() *FileUpdateOne { + fuo.mutation.ClearEntities() + return fuo +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (fuo *FileUpdateOne) RemoveEntityIDs(ids ...int) *FileUpdateOne { + fuo.mutation.RemoveEntityIDs(ids...) + return fuo +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (fuo *FileUpdateOne) RemoveEntities(e ...*Entity) *FileUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return fuo.RemoveEntityIDs(ids...) +} + +// ClearShares clears all "shares" edges to the Share entity. +func (fuo *FileUpdateOne) ClearShares() *FileUpdateOne { + fuo.mutation.ClearShares() + return fuo +} + +// RemoveShareIDs removes the "shares" edge to Share entities by IDs. +func (fuo *FileUpdateOne) RemoveShareIDs(ids ...int) *FileUpdateOne { + fuo.mutation.RemoveShareIDs(ids...) + return fuo +} + +// RemoveShares removes "shares" edges to Share entities. +func (fuo *FileUpdateOne) RemoveShares(s ...*Share) *FileUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return fuo.RemoveShareIDs(ids...) +} + +// ClearDirectLinks clears all "direct_links" edges to the DirectLink entity. +func (fuo *FileUpdateOne) ClearDirectLinks() *FileUpdateOne { + fuo.mutation.ClearDirectLinks() + return fuo +} + +// RemoveDirectLinkIDs removes the "direct_links" edge to DirectLink entities by IDs. +func (fuo *FileUpdateOne) RemoveDirectLinkIDs(ids ...int) *FileUpdateOne { + fuo.mutation.RemoveDirectLinkIDs(ids...) + return fuo +} + +// RemoveDirectLinks removes "direct_links" edges to DirectLink entities. +func (fuo *FileUpdateOne) RemoveDirectLinks(d ...*DirectLink) *FileUpdateOne { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return fuo.RemoveDirectLinkIDs(ids...) +} + +// Where appends a list predicates to the FileUpdate builder. +func (fuo *FileUpdateOne) Where(ps ...predicate.File) *FileUpdateOne { + fuo.mutation.Where(ps...) 
+ return fuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (fuo *FileUpdateOne) Select(field string, fields ...string) *FileUpdateOne { + fuo.fields = append([]string{field}, fields...) + return fuo +} + +// Save executes the query and returns the updated File entity. +func (fuo *FileUpdateOne) Save(ctx context.Context) (*File, error) { + if err := fuo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, fuo.sqlSave, fuo.mutation, fuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (fuo *FileUpdateOne) SaveX(ctx context.Context) *File { + node, err := fuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (fuo *FileUpdateOne) Exec(ctx context.Context) error { + _, err := fuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (fuo *FileUpdateOne) ExecX(ctx context.Context) { + if err := fuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (fuo *FileUpdateOne) defaults() error { + if _, ok := fuo.mutation.UpdatedAt(); !ok { + if file.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized file.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := file.UpdateDefaultUpdatedAt() + fuo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (fuo *FileUpdateOne) check() error { + if _, ok := fuo.mutation.OwnerID(); fuo.mutation.OwnerCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "File.owner"`) + } + return nil +} + +func (fuo *FileUpdateOne) sqlSave(ctx context.Context) (_node *File, err error) { + if err := fuo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(file.Table, file.Columns, sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt)) + id, ok := fuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "File.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := fuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, file.FieldID) + for _, f := range fields { + if !file.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != file.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := fuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := fuo.mutation.UpdatedAt(); ok { + _spec.SetField(file.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := fuo.mutation.DeletedAt(); ok { + _spec.SetField(file.FieldDeletedAt, field.TypeTime, value) + } + if fuo.mutation.DeletedAtCleared() { + _spec.ClearField(file.FieldDeletedAt, field.TypeTime) + } + if value, ok := fuo.mutation.GetType(); ok { + _spec.SetField(file.FieldType, field.TypeInt, value) + } + if value, ok := fuo.mutation.AddedType(); ok { + _spec.AddField(file.FieldType, field.TypeInt, value) + } + if value, ok := fuo.mutation.Name(); ok { + _spec.SetField(file.FieldName, field.TypeString, value) + } + if value, ok := fuo.mutation.Size(); ok { + _spec.SetField(file.FieldSize, field.TypeInt64, value) + } + if value, ok := 
fuo.mutation.AddedSize(); ok { + _spec.AddField(file.FieldSize, field.TypeInt64, value) + } + if value, ok := fuo.mutation.PrimaryEntity(); ok { + _spec.SetField(file.FieldPrimaryEntity, field.TypeInt, value) + } + if value, ok := fuo.mutation.AddedPrimaryEntity(); ok { + _spec.AddField(file.FieldPrimaryEntity, field.TypeInt, value) + } + if fuo.mutation.PrimaryEntityCleared() { + _spec.ClearField(file.FieldPrimaryEntity, field.TypeInt) + } + if value, ok := fuo.mutation.IsSymbolic(); ok { + _spec.SetField(file.FieldIsSymbolic, field.TypeBool, value) + } + if value, ok := fuo.mutation.Props(); ok { + _spec.SetField(file.FieldProps, field.TypeJSON, value) + } + if fuo.mutation.PropsCleared() { + _spec.ClearField(file.FieldProps, field.TypeJSON) + } + if fuo.mutation.OwnerCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.OwnerTable, + Columns: []string{file.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.OwnerIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.OwnerTable, + Columns: []string{file.OwnerColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.StoragePoliciesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.StoragePoliciesTable, + Columns: []string{file.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.StoragePoliciesTable, + Columns: []string{file.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.ParentCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.ParentTable, + Columns: []string{file.ParentColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.ParentIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: file.ParentTable, + Columns: []string{file.ParentColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.ChildrenCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, 
edge) + } + if nodes := fuo.mutation.RemovedChildrenIDs(); len(nodes) > 0 && !fuo.mutation.ChildrenCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.ChildrenIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.ChildrenTable, + Columns: []string{file.ChildrenColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.MetadataCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.RemovedMetadataIDs(); len(nodes) > 0 && !fuo.mutation.MetadataCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.MetadataIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.MetadataTable, + Columns: []string{file.MetadataColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.RemovedEntitiesIDs(); len(nodes) > 0 && !fuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: file.EntitiesTable, + Columns: file.EntitiesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = 
append(_spec.Edges.Add, edge) + } + if fuo.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.RemovedSharesIDs(); len(nodes) > 0 && !fuo.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.SharesTable, + Columns: []string{file.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if fuo.mutation.DirectLinksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.RemovedDirectLinksIDs(); len(nodes) > 0 && !fuo.mutation.DirectLinksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := fuo.mutation.DirectLinksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: file.DirectLinksTable, + Columns: []string{file.DirectLinksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(directlink.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &File{config: fuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, fuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{file.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + fuo.mutation.done = true + return _node, nil +} diff --git a/ent/generate.go b/ent/generate.go new file mode 100644 index 00000000..8232761c --- /dev/null +++ b/ent/generate.go @@ -0,0 +1,3 @@ +package ent + +//go:generate go run -mod=mod entc.go diff --git a/ent/group.go b/ent/group.go new file mode 100644 index 00000000..80882cb5 --- /dev/null +++ b/ent/group.go @@ -0,0 +1,265 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// Group is the model entity for the Group schema. +type Group struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // MaxStorage holds the value of the "max_storage" field. + MaxStorage int64 `json:"max_storage,omitempty"` + // SpeedLimit holds the value of the "speed_limit" field. + SpeedLimit int `json:"speed_limit,omitempty"` + // Permissions holds the value of the "permissions" field. + Permissions *boolset.BooleanSet `json:"permissions,omitempty"` + // Settings holds the value of the "settings" field. + Settings *types.GroupSetting `json:"settings,omitempty"` + // StoragePolicyID holds the value of the "storage_policy_id" field. + StoragePolicyID int `json:"storage_policy_id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the GroupQuery when eager-loading is set. + Edges GroupEdges `json:"edges"` + selectValues sql.SelectValues +} + +// GroupEdges holds the relations/edges for other nodes in the graph. +type GroupEdges struct { + // Users holds the value of the users edge. + Users []*User `json:"users,omitempty"` + // StoragePolicies holds the value of the storage_policies edge. + StoragePolicies *StoragePolicy `json:"storage_policies,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// UsersOrErr returns the Users value or an error if the edge +// was not loaded in eager-loading. +func (e GroupEdges) UsersOrErr() ([]*User, error) { + if e.loadedTypes[0] { + return e.Users, nil + } + return nil, &NotLoadedError{edge: "users"} +} + +// StoragePoliciesOrErr returns the StoragePolicies value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e GroupEdges) StoragePoliciesOrErr() (*StoragePolicy, error) { + if e.loadedTypes[1] { + if e.StoragePolicies == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: storagepolicy.Label} + } + return e.StoragePolicies, nil + } + return nil, &NotLoadedError{edge: "storage_policies"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*Group) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case group.FieldSettings: + values[i] = new([]byte) + case group.FieldPermissions: + values[i] = new(boolset.BooleanSet) + case group.FieldID, group.FieldMaxStorage, group.FieldSpeedLimit, group.FieldStoragePolicyID: + values[i] = new(sql.NullInt64) + case group.FieldName: + values[i] = new(sql.NullString) + case group.FieldCreatedAt, group.FieldUpdatedAt, group.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Group fields. +func (gr *Group) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case group.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + gr.ID = int(value.Int64) + case group.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + gr.CreatedAt = value.Time + } + case group.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + gr.UpdatedAt = value.Time + } + case group.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + gr.DeletedAt = new(time.Time) + *gr.DeletedAt = value.Time + } + case group.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + gr.Name = value.String + } + case group.FieldMaxStorage: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field max_storage", values[i]) + } else if value.Valid { + gr.MaxStorage = value.Int64 + } + case group.FieldSpeedLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field speed_limit", values[i]) + } else if value.Valid { + gr.SpeedLimit = int(value.Int64) + } + case group.FieldPermissions: + if value, ok := values[i].(*boolset.BooleanSet); !ok { + return fmt.Errorf("unexpected type %T for field permissions", values[i]) + } else if value != nil { + gr.Permissions = value + } + case group.FieldSettings: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field settings", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &gr.Settings); err != nil { + return fmt.Errorf("unmarshal field settings: %w", err) + } + } + case group.FieldStoragePolicyID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field storage_policy_id", values[i]) + } else if value.Valid { + gr.StoragePolicyID = int(value.Int64) + } + default: + gr.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Group. +// This includes values selected through modifiers, order, etc. 
+func (gr *Group) Value(name string) (ent.Value, error) { + return gr.selectValues.Get(name) +} + +// QueryUsers queries the "users" edge of the Group entity. +func (gr *Group) QueryUsers() *UserQuery { + return NewGroupClient(gr.config).QueryUsers(gr) +} + +// QueryStoragePolicies queries the "storage_policies" edge of the Group entity. +func (gr *Group) QueryStoragePolicies() *StoragePolicyQuery { + return NewGroupClient(gr.config).QueryStoragePolicies(gr) +} + +// Update returns a builder for updating this Group. +// Note that you need to call Group.Unwrap() before calling this method if this Group +// was returned from a transaction, and the transaction was committed or rolled back. +func (gr *Group) Update() *GroupUpdateOne { + return NewGroupClient(gr.config).UpdateOne(gr) +} + +// Unwrap unwraps the Group entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (gr *Group) Unwrap() *Group { + _tx, ok := gr.config.driver.(*txDriver) + if !ok { + panic("ent: Group is not a transactional entity") + } + gr.config.driver = _tx.drv + return gr +} + +// String implements the fmt.Stringer. +func (gr *Group) String() string { + var builder strings.Builder + builder.WriteString("Group(") + builder.WriteString(fmt.Sprintf("id=%v, ", gr.ID)) + builder.WriteString("created_at=") + builder.WriteString(gr.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(gr.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := gr.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(gr.Name) + builder.WriteString(", ") + builder.WriteString("max_storage=") + builder.WriteString(fmt.Sprintf("%v", gr.MaxStorage)) + builder.WriteString(", ") + builder.WriteString("speed_limit=") + builder.WriteString(fmt.Sprintf("%v", gr.SpeedLimit)) + builder.WriteString(", ") + builder.WriteString("permissions=") + builder.WriteString(fmt.Sprintf("%v", gr.Permissions)) + builder.WriteString(", ") + builder.WriteString("settings=") + builder.WriteString(fmt.Sprintf("%v", gr.Settings)) + builder.WriteString(", ") + builder.WriteString("storage_policy_id=") + builder.WriteString(fmt.Sprintf("%v", gr.StoragePolicyID)) + builder.WriteByte(')') + return builder.String() +} + +// SetUsers manually set the edge as loaded state. +func (e *Group) SetUsers(v []*User) { + e.Edges.Users = v + e.Edges.loadedTypes[0] = true +} + +// SetStoragePolicies manually set the edge as loaded state. +func (e *Group) SetStoragePolicies(v *StoragePolicy) { + e.Edges.StoragePolicies = v + e.Edges.loadedTypes[1] = true +} + +// Groups is a parsable slice of Group. +type Groups []*Group diff --git a/ent/group/group.go b/ent/group/group.go new file mode 100644 index 00000000..7ceee545 --- /dev/null +++ b/ent/group/group.go @@ -0,0 +1,177 @@ +// Code generated by ent, DO NOT EDIT. + +package group + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +const ( + // Label holds the string label denoting the group type in the database. + Label = "group" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. 
+ FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldMaxStorage holds the string denoting the max_storage field in the database. + FieldMaxStorage = "max_storage" + // FieldSpeedLimit holds the string denoting the speed_limit field in the database. + FieldSpeedLimit = "speed_limit" + // FieldPermissions holds the string denoting the permissions field in the database. + FieldPermissions = "permissions" + // FieldSettings holds the string denoting the settings field in the database. + FieldSettings = "settings" + // FieldStoragePolicyID holds the string denoting the storage_policy_id field in the database. + FieldStoragePolicyID = "storage_policy_id" + // EdgeUsers holds the string denoting the users edge name in mutations. + EdgeUsers = "users" + // EdgeStoragePolicies holds the string denoting the storage_policies edge name in mutations. + EdgeStoragePolicies = "storage_policies" + // Table holds the table name of the group in the database. + Table = "groups" + // UsersTable is the table that holds the users relation/edge. + UsersTable = "users" + // UsersInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UsersInverseTable = "users" + // UsersColumn is the table column denoting the users relation/edge. + UsersColumn = "group_users" + // StoragePoliciesTable is the table that holds the storage_policies relation/edge. + StoragePoliciesTable = "groups" + // StoragePoliciesInverseTable is the table name for the StoragePolicy entity. + // It exists in this package in order to avoid circular dependency with the "storagepolicy" package. + StoragePoliciesInverseTable = "storage_policies" + // StoragePoliciesColumn is the table column denoting the storage_policies relation/edge. + StoragePoliciesColumn = "storage_policy_id" +) + +// Columns holds all SQL columns for group fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldMaxStorage, + FieldSpeedLimit, + FieldPermissions, + FieldSettings, + FieldStoragePolicyID, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultSettings holds the default value on creation for the "settings" field. + DefaultSettings *types.GroupSetting +) + +// OrderOption defines the ordering options for the Group queries. 
+type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByMaxStorage orders the results by the max_storage field. +func ByMaxStorage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMaxStorage, opts...).ToFunc() +} + +// BySpeedLimit orders the results by the speed_limit field. +func BySpeedLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSpeedLimit, opts...).ToFunc() +} + +// ByStoragePolicyID orders the results by the storage_policy_id field. +func ByStoragePolicyID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStoragePolicyID, opts...).ToFunc() +} + +// ByUsersCount orders the results by users count. +func ByUsersCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsersStep(), opts...) + } +} + +// ByUsers orders the results by users terms. +func ByUsers(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsersStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByStoragePoliciesField orders the results by storage_policies field. +func ByStoragePoliciesField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStoragePoliciesStep(), sql.OrderByField(field, opts...)) + } +} +func newUsersStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsersInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsersTable, UsersColumn), + ) +} +func newStoragePoliciesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StoragePoliciesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePoliciesTable, StoragePoliciesColumn), + ) +} diff --git a/ent/group/where.go b/ent/group/where.go new file mode 100644 index 00000000..4c414ff5 --- /dev/null +++ b/ent/group/where.go @@ -0,0 +1,533 @@ +// Code generated by ent, DO NOT EDIT. + +package group + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. 
+func IDNEQ(id int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldName, v)) +} + +// MaxStorage applies equality check predicate on the "max_storage" field. It's identical to MaxStorageEQ. +func MaxStorage(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMaxStorage, v)) +} + +// SpeedLimit applies equality check predicate on the "speed_limit" field. It's identical to SpeedLimitEQ. +func SpeedLimit(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSpeedLimit, v)) +} + +// Permissions applies equality check predicate on the "permissions" field. It's identical to PermissionsEQ. +func Permissions(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldPermissions, v)) +} + +// StoragePolicyID applies equality check predicate on the "storage_policy_id" field. It's identical to StoragePolicyIDEQ. +func StoragePolicyID(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldStoragePolicyID, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. 
+func CreatedAtGT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. 
+func DeletedAtLTE(v time.Time) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Group { + return predicate.Group(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Group { + return predicate.Group(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Group { + return predicate.Group(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Group { + return predicate.Group(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Group { + return predicate.Group(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Group { + return predicate.Group(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Group { + return predicate.Group(sql.FieldContainsFold(FieldName, v)) +} + +// MaxStorageEQ applies the EQ predicate on the "max_storage" field. +func MaxStorageEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMaxStorage, v)) +} + +// MaxStorageNEQ applies the NEQ predicate on the "max_storage" field. +func MaxStorageNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldMaxStorage, v)) +} + +// MaxStorageIn applies the In predicate on the "max_storage" field. +func MaxStorageIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldMaxStorage, vs...)) +} + +// MaxStorageNotIn applies the NotIn predicate on the "max_storage" field. +func MaxStorageNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldMaxStorage, vs...)) +} + +// MaxStorageGT applies the GT predicate on the "max_storage" field. 
+func MaxStorageGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldMaxStorage, v)) +} + +// MaxStorageGTE applies the GTE predicate on the "max_storage" field. +func MaxStorageGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldMaxStorage, v)) +} + +// MaxStorageLT applies the LT predicate on the "max_storage" field. +func MaxStorageLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldMaxStorage, v)) +} + +// MaxStorageLTE applies the LTE predicate on the "max_storage" field. +func MaxStorageLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldMaxStorage, v)) +} + +// MaxStorageIsNil applies the IsNil predicate on the "max_storage" field. +func MaxStorageIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldMaxStorage)) +} + +// MaxStorageNotNil applies the NotNil predicate on the "max_storage" field. +func MaxStorageNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldMaxStorage)) +} + +// SpeedLimitEQ applies the EQ predicate on the "speed_limit" field. +func SpeedLimitEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSpeedLimit, v)) +} + +// SpeedLimitNEQ applies the NEQ predicate on the "speed_limit" field. +func SpeedLimitNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSpeedLimit, v)) +} + +// SpeedLimitIn applies the In predicate on the "speed_limit" field. +func SpeedLimitIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSpeedLimit, vs...)) +} + +// SpeedLimitNotIn applies the NotIn predicate on the "speed_limit" field. +func SpeedLimitNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSpeedLimit, vs...)) +} + +// SpeedLimitGT applies the GT predicate on the "speed_limit" field. +func SpeedLimitGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSpeedLimit, v)) +} + +// SpeedLimitGTE applies the GTE predicate on the "speed_limit" field. +func SpeedLimitGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSpeedLimit, v)) +} + +// SpeedLimitLT applies the LT predicate on the "speed_limit" field. +func SpeedLimitLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSpeedLimit, v)) +} + +// SpeedLimitLTE applies the LTE predicate on the "speed_limit" field. +func SpeedLimitLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSpeedLimit, v)) +} + +// SpeedLimitIsNil applies the IsNil predicate on the "speed_limit" field. +func SpeedLimitIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSpeedLimit)) +} + +// SpeedLimitNotNil applies the NotNil predicate on the "speed_limit" field. +func SpeedLimitNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSpeedLimit)) +} + +// PermissionsEQ applies the EQ predicate on the "permissions" field. +func PermissionsEQ(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldPermissions, v)) +} + +// PermissionsNEQ applies the NEQ predicate on the "permissions" field. +func PermissionsNEQ(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldPermissions, v)) +} + +// PermissionsIn applies the In predicate on the "permissions" field. +func PermissionsIn(vs ...*boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldIn(FieldPermissions, vs...)) +} + +// PermissionsNotIn applies the NotIn predicate on the "permissions" field. 
+func PermissionsNotIn(vs ...*boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldPermissions, vs...)) +} + +// PermissionsGT applies the GT predicate on the "permissions" field. +func PermissionsGT(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldGT(FieldPermissions, v)) +} + +// PermissionsGTE applies the GTE predicate on the "permissions" field. +func PermissionsGTE(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldPermissions, v)) +} + +// PermissionsLT applies the LT predicate on the "permissions" field. +func PermissionsLT(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldLT(FieldPermissions, v)) +} + +// PermissionsLTE applies the LTE predicate on the "permissions" field. +func PermissionsLTE(v *boolset.BooleanSet) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldPermissions, v)) +} + +// SettingsIsNil applies the IsNil predicate on the "settings" field. +func SettingsIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSettings)) +} + +// SettingsNotNil applies the NotNil predicate on the "settings" field. +func SettingsNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSettings)) +} + +// StoragePolicyIDEQ applies the EQ predicate on the "storage_policy_id" field. +func StoragePolicyIDEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldStoragePolicyID, v)) +} + +// StoragePolicyIDNEQ applies the NEQ predicate on the "storage_policy_id" field. +func StoragePolicyIDNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldStoragePolicyID, v)) +} + +// StoragePolicyIDIn applies the In predicate on the "storage_policy_id" field. +func StoragePolicyIDIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldStoragePolicyID, vs...)) +} + +// StoragePolicyIDNotIn applies the NotIn predicate on the "storage_policy_id" field. +func StoragePolicyIDNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldStoragePolicyID, vs...)) +} + +// StoragePolicyIDIsNil applies the IsNil predicate on the "storage_policy_id" field. +func StoragePolicyIDIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldStoragePolicyID)) +} + +// StoragePolicyIDNotNil applies the NotNil predicate on the "storage_policy_id" field. +func StoragePolicyIDNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldStoragePolicyID)) +} + +// HasUsers applies the HasEdge predicate on the "users" edge. +func HasUsers() predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsersTable, UsersColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsersWith applies the HasEdge predicate on the "users" edge with a given conditions (other predicates). +func HasUsersWith(preds ...predicate.User) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := newUsersStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasStoragePolicies applies the HasEdge predicate on the "storage_policies" edge. 
+func HasStoragePolicies() predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StoragePoliciesTable, StoragePoliciesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasStoragePoliciesWith applies the HasEdge predicate on the "storage_policies" edge with a given conditions (other predicates). +func HasStoragePoliciesWith(preds ...predicate.StoragePolicy) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := newStoragePoliciesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Group) predicate.Group { + return predicate.Group(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Group) predicate.Group { + return predicate.Group(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Group) predicate.Group { + return predicate.Group(sql.NotPredicates(p)) +} diff --git a/ent/group_create.go b/ent/group_create.go new file mode 100644 index 00000000..e11afdd6 --- /dev/null +++ b/ent/group_create.go @@ -0,0 +1,1130 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// GroupCreate is the builder for creating a Group entity. +type GroupCreate struct { + config + mutation *GroupMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (gc *GroupCreate) SetCreatedAt(t time.Time) *GroupCreate { + gc.mutation.SetCreatedAt(t) + return gc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (gc *GroupCreate) SetNillableCreatedAt(t *time.Time) *GroupCreate { + if t != nil { + gc.SetCreatedAt(*t) + } + return gc +} + +// SetUpdatedAt sets the "updated_at" field. +func (gc *GroupCreate) SetUpdatedAt(t time.Time) *GroupCreate { + gc.mutation.SetUpdatedAt(t) + return gc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (gc *GroupCreate) SetNillableUpdatedAt(t *time.Time) *GroupCreate { + if t != nil { + gc.SetUpdatedAt(*t) + } + return gc +} + +// SetDeletedAt sets the "deleted_at" field. +func (gc *GroupCreate) SetDeletedAt(t time.Time) *GroupCreate { + gc.mutation.SetDeletedAt(t) + return gc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (gc *GroupCreate) SetNillableDeletedAt(t *time.Time) *GroupCreate { + if t != nil { + gc.SetDeletedAt(*t) + } + return gc +} + +// SetName sets the "name" field. +func (gc *GroupCreate) SetName(s string) *GroupCreate { + gc.mutation.SetName(s) + return gc +} + +// SetMaxStorage sets the "max_storage" field. +func (gc *GroupCreate) SetMaxStorage(i int64) *GroupCreate { + gc.mutation.SetMaxStorage(i) + return gc +} + +// SetNillableMaxStorage sets the "max_storage" field if the given value is not nil. 
+func (gc *GroupCreate) SetNillableMaxStorage(i *int64) *GroupCreate { + if i != nil { + gc.SetMaxStorage(*i) + } + return gc +} + +// SetSpeedLimit sets the "speed_limit" field. +func (gc *GroupCreate) SetSpeedLimit(i int) *GroupCreate { + gc.mutation.SetSpeedLimit(i) + return gc +} + +// SetNillableSpeedLimit sets the "speed_limit" field if the given value is not nil. +func (gc *GroupCreate) SetNillableSpeedLimit(i *int) *GroupCreate { + if i != nil { + gc.SetSpeedLimit(*i) + } + return gc +} + +// SetPermissions sets the "permissions" field. +func (gc *GroupCreate) SetPermissions(bs *boolset.BooleanSet) *GroupCreate { + gc.mutation.SetPermissions(bs) + return gc +} + +// SetSettings sets the "settings" field. +func (gc *GroupCreate) SetSettings(ts *types.GroupSetting) *GroupCreate { + gc.mutation.SetSettings(ts) + return gc +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (gc *GroupCreate) SetStoragePolicyID(i int) *GroupCreate { + gc.mutation.SetStoragePolicyID(i) + return gc +} + +// SetNillableStoragePolicyID sets the "storage_policy_id" field if the given value is not nil. +func (gc *GroupCreate) SetNillableStoragePolicyID(i *int) *GroupCreate { + if i != nil { + gc.SetStoragePolicyID(*i) + } + return gc +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (gc *GroupCreate) AddUserIDs(ids ...int) *GroupCreate { + gc.mutation.AddUserIDs(ids...) + return gc +} + +// AddUsers adds the "users" edges to the User entity. +func (gc *GroupCreate) AddUsers(u ...*User) *GroupCreate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gc.AddUserIDs(ids...) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (gc *GroupCreate) SetStoragePoliciesID(id int) *GroupCreate { + gc.mutation.SetStoragePoliciesID(id) + return gc +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. +func (gc *GroupCreate) SetNillableStoragePoliciesID(id *int) *GroupCreate { + if id != nil { + gc = gc.SetStoragePoliciesID(*id) + } + return gc +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. +func (gc *GroupCreate) SetStoragePolicies(s *StoragePolicy) *GroupCreate { + return gc.SetStoragePoliciesID(s.ID) +} + +// Mutation returns the GroupMutation object of the builder. +func (gc *GroupCreate) Mutation() *GroupMutation { + return gc.mutation +} + +// Save creates the Group in the database. +func (gc *GroupCreate) Save(ctx context.Context) (*Group, error) { + if err := gc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, gc.sqlSave, gc.mutation, gc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (gc *GroupCreate) SaveX(ctx context.Context) *Group { + v, err := gc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (gc *GroupCreate) Exec(ctx context.Context) error { + _, err := gc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (gc *GroupCreate) ExecX(ctx context.Context) { + if err := gc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (gc *GroupCreate) defaults() error { + if _, ok := gc.mutation.CreatedAt(); !ok { + if group.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized group.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := group.DefaultCreatedAt() + gc.mutation.SetCreatedAt(v) + } + if _, ok := gc.mutation.UpdatedAt(); !ok { + if group.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized group.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := group.DefaultUpdatedAt() + gc.mutation.SetUpdatedAt(v) + } + if _, ok := gc.mutation.Settings(); !ok { + v := group.DefaultSettings + gc.mutation.SetSettings(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (gc *GroupCreate) check() error { + if _, ok := gc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Group.created_at"`)} + } + if _, ok := gc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Group.updated_at"`)} + } + if _, ok := gc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Group.name"`)} + } + if _, ok := gc.mutation.Permissions(); !ok { + return &ValidationError{Name: "permissions", err: errors.New(`ent: missing required field "Group.permissions"`)} + } + return nil +} + +func (gc *GroupCreate) sqlSave(ctx context.Context) (*Group, error) { + if err := gc.check(); err != nil { + return nil, err + } + _node, _spec := gc.createSpec() + if err := sqlgraph.CreateNode(ctx, gc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + gc.mutation.id = &_node.ID + gc.mutation.done = true + return _node, nil +} + +func (gc *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { + var ( + _node = &Group{config: gc.config} + _spec = sqlgraph.NewCreateSpec(group.Table, sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt)) + ) + + if id, ok := gc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = gc.conflict + if value, ok := gc.mutation.CreatedAt(); ok { + _spec.SetField(group.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := gc.mutation.UpdatedAt(); ok { + _spec.SetField(group.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := gc.mutation.DeletedAt(); ok { + _spec.SetField(group.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := gc.mutation.Name(); ok { + _spec.SetField(group.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := gc.mutation.MaxStorage(); ok { + _spec.SetField(group.FieldMaxStorage, field.TypeInt64, value) + _node.MaxStorage = value + } + if value, ok := gc.mutation.SpeedLimit(); ok { + _spec.SetField(group.FieldSpeedLimit, field.TypeInt, value) + _node.SpeedLimit = value + } + if value, ok := gc.mutation.Permissions(); ok { + _spec.SetField(group.FieldPermissions, field.TypeBytes, value) + _node.Permissions = value + } + if value, ok := gc.mutation.Settings(); ok { + _spec.SetField(group.FieldSettings, field.TypeJSON, value) + _node.Settings = value + } + if nodes := gc.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: 
[]string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := gc.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: group.StoragePoliciesTable, + Columns: []string{group.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.StoragePolicyID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Group.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.GroupUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (gc *GroupCreate) OnConflict(opts ...sql.ConflictOption) *GroupUpsertOne { + gc.conflict = opts + return &GroupUpsertOne{ + create: gc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Group.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (gc *GroupCreate) OnConflictColumns(columns ...string) *GroupUpsertOne { + gc.conflict = append(gc.conflict, sql.ConflictColumns(columns...)) + return &GroupUpsertOne{ + create: gc, + } +} + +type ( + // GroupUpsertOne is the builder for "upsert"-ing + // one Group node. + GroupUpsertOne struct { + create *GroupCreate + } + + // GroupUpsert is the "OnConflict" setter. + GroupUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *GroupUpsert) SetUpdatedAt(v time.Time) *GroupUpsert { + u.Set(group.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *GroupUpsert) UpdateUpdatedAt() *GroupUpsert { + u.SetExcluded(group.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *GroupUpsert) SetDeletedAt(v time.Time) *GroupUpsert { + u.Set(group.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *GroupUpsert) UpdateDeletedAt() *GroupUpsert { + u.SetExcluded(group.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *GroupUpsert) ClearDeletedAt() *GroupUpsert { + u.SetNull(group.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *GroupUpsert) SetName(v string) *GroupUpsert { + u.Set(group.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *GroupUpsert) UpdateName() *GroupUpsert { + u.SetExcluded(group.FieldName) + return u +} + +// SetMaxStorage sets the "max_storage" field. +func (u *GroupUpsert) SetMaxStorage(v int64) *GroupUpsert { + u.Set(group.FieldMaxStorage, v) + return u +} + +// UpdateMaxStorage sets the "max_storage" field to the value that was provided on create. 
+func (u *GroupUpsert) UpdateMaxStorage() *GroupUpsert { + u.SetExcluded(group.FieldMaxStorage) + return u +} + +// AddMaxStorage adds v to the "max_storage" field. +func (u *GroupUpsert) AddMaxStorage(v int64) *GroupUpsert { + u.Add(group.FieldMaxStorage, v) + return u +} + +// ClearMaxStorage clears the value of the "max_storage" field. +func (u *GroupUpsert) ClearMaxStorage() *GroupUpsert { + u.SetNull(group.FieldMaxStorage) + return u +} + +// SetSpeedLimit sets the "speed_limit" field. +func (u *GroupUpsert) SetSpeedLimit(v int) *GroupUpsert { + u.Set(group.FieldSpeedLimit, v) + return u +} + +// UpdateSpeedLimit sets the "speed_limit" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSpeedLimit() *GroupUpsert { + u.SetExcluded(group.FieldSpeedLimit) + return u +} + +// AddSpeedLimit adds v to the "speed_limit" field. +func (u *GroupUpsert) AddSpeedLimit(v int) *GroupUpsert { + u.Add(group.FieldSpeedLimit, v) + return u +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (u *GroupUpsert) ClearSpeedLimit() *GroupUpsert { + u.SetNull(group.FieldSpeedLimit) + return u +} + +// SetPermissions sets the "permissions" field. +func (u *GroupUpsert) SetPermissions(v *boolset.BooleanSet) *GroupUpsert { + u.Set(group.FieldPermissions, v) + return u +} + +// UpdatePermissions sets the "permissions" field to the value that was provided on create. +func (u *GroupUpsert) UpdatePermissions() *GroupUpsert { + u.SetExcluded(group.FieldPermissions) + return u +} + +// SetSettings sets the "settings" field. +func (u *GroupUpsert) SetSettings(v *types.GroupSetting) *GroupUpsert { + u.Set(group.FieldSettings, v) + return u +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSettings() *GroupUpsert { + u.SetExcluded(group.FieldSettings) + return u +} + +// ClearSettings clears the value of the "settings" field. +func (u *GroupUpsert) ClearSettings() *GroupUpsert { + u.SetNull(group.FieldSettings) + return u +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (u *GroupUpsert) SetStoragePolicyID(v int) *GroupUpsert { + u.Set(group.FieldStoragePolicyID, v) + return u +} + +// UpdateStoragePolicyID sets the "storage_policy_id" field to the value that was provided on create. +func (u *GroupUpsert) UpdateStoragePolicyID() *GroupUpsert { + u.SetExcluded(group.FieldStoragePolicyID) + return u +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (u *GroupUpsert) ClearStoragePolicyID() *GroupUpsert { + u.SetNull(group.FieldStoragePolicyID) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Group.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *GroupUpsertOne) UpdateNewValues() *GroupUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(group.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Group.Create(). +// OnConflict(sql.ResolveWithIgnore()). 
+// Exec(ctx) +func (u *GroupUpsertOne) Ignore() *GroupUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *GroupUpsertOne) DoNothing() *GroupUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the GroupCreate.OnConflict +// documentation for more info. +func (u *GroupUpsertOne) Update(set func(*GroupUpsert)) *GroupUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&GroupUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *GroupUpsertOne) SetUpdatedAt(v time.Time) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateUpdatedAt() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *GroupUpsertOne) SetDeletedAt(v time.Time) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateDeletedAt() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *GroupUpsertOne) ClearDeletedAt() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *GroupUpsertOne) SetName(v string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateName() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateName() + }) +} + +// SetMaxStorage sets the "max_storage" field. +func (u *GroupUpsertOne) SetMaxStorage(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMaxStorage(v) + }) +} + +// AddMaxStorage adds v to the "max_storage" field. +func (u *GroupUpsertOne) AddMaxStorage(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddMaxStorage(v) + }) +} + +// UpdateMaxStorage sets the "max_storage" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMaxStorage() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMaxStorage() + }) +} + +// ClearMaxStorage clears the value of the "max_storage" field. +func (u *GroupUpsertOne) ClearMaxStorage() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearMaxStorage() + }) +} + +// SetSpeedLimit sets the "speed_limit" field. +func (u *GroupUpsertOne) SetSpeedLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSpeedLimit(v) + }) +} + +// AddSpeedLimit adds v to the "speed_limit" field. +func (u *GroupUpsertOne) AddSpeedLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSpeedLimit(v) + }) +} + +// UpdateSpeedLimit sets the "speed_limit" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateSpeedLimit() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSpeedLimit() + }) +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (u *GroupUpsertOne) ClearSpeedLimit() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSpeedLimit() + }) +} + +// SetPermissions sets the "permissions" field. +func (u *GroupUpsertOne) SetPermissions(v *boolset.BooleanSet) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetPermissions(v) + }) +} + +// UpdatePermissions sets the "permissions" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdatePermissions() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdatePermissions() + }) +} + +// SetSettings sets the "settings" field. +func (u *GroupUpsertOne) SetSettings(v *types.GroupSetting) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSettings() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *GroupUpsertOne) ClearSettings() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSettings() + }) +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (u *GroupUpsertOne) SetStoragePolicyID(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetStoragePolicyID(v) + }) +} + +// UpdateStoragePolicyID sets the "storage_policy_id" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateStoragePolicyID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateStoragePolicyID() + }) +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (u *GroupUpsertOne) ClearStoragePolicyID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearStoragePolicyID() + }) +} + +// Exec executes the query. +func (u *GroupUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for GroupCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *GroupUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *GroupUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *GroupUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *GroupCreate) SetRawID(t int) *GroupCreate { + m.mutation.SetRawID(t) + return m +} + +// GroupCreateBulk is the builder for creating many Group entities in bulk. +type GroupCreateBulk struct { + config + err error + builders []*GroupCreate + conflict []sql.ConflictOption +} + +// Save creates the Group entities in the database. 
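The upsert helpers above compose with the same create builder; a sketch of an insert-or-update, assuming the "name" column carries a unique index (client, ctx and perm as in the earlier sketch, with the generated ent/group package imported for its column constant):

// upsertGroup inserts a Group named "Guests"; if a row already exists for the
// hypothetical unique "name" column, the conflict is resolved with the new values.
func upsertGroup(ctx context.Context, client *ent.Client, perm *boolset.BooleanSet) (int, error) {
	return client.Group.Create().
		SetName("Guests").
		SetPermissions(perm).
		OnConflictColumns(group.FieldName). // conflict target; uniqueness is assumed here
		UpdateNewValues().                  // keep the newly proposed values, preserving created_at
		ID(ctx)                             // returns the inserted/updated ID
}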
+func (gcb *GroupCreateBulk) Save(ctx context.Context) ([]*Group, error) { + if gcb.err != nil { + return nil, gcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(gcb.builders)) + nodes := make([]*Group, len(gcb.builders)) + mutators := make([]Mutator, len(gcb.builders)) + for i := range gcb.builders { + func(i int, root context.Context) { + builder := gcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*GroupMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, gcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = gcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, gcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, gcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (gcb *GroupCreateBulk) SaveX(ctx context.Context) []*Group { + v, err := gcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (gcb *GroupCreateBulk) Exec(ctx context.Context) error { + _, err := gcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (gcb *GroupCreateBulk) ExecX(ctx context.Context) { + if err := gcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Group.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.GroupUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (gcb *GroupCreateBulk) OnConflict(opts ...sql.ConflictOption) *GroupUpsertBulk { + gcb.conflict = opts + return &GroupUpsertBulk{ + create: gcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Group.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (gcb *GroupCreateBulk) OnConflictColumns(columns ...string) *GroupUpsertBulk { + gcb.conflict = append(gcb.conflict, sql.ConflictColumns(columns...)) + return &GroupUpsertBulk{ + create: gcb, + } +} + +// GroupUpsertBulk is the builder for "upsert"-ing +// a bulk of Group nodes. +type GroupUpsertBulk struct { + create *GroupCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Group.Create(). 
+// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *GroupUpsertBulk) UpdateNewValues() *GroupUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(group.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Group.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *GroupUpsertBulk) Ignore() *GroupUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *GroupUpsertBulk) DoNothing() *GroupUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the GroupCreateBulk.OnConflict +// documentation for more info. +func (u *GroupUpsertBulk) Update(set func(*GroupUpsert)) *GroupUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&GroupUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *GroupUpsertBulk) SetUpdatedAt(v time.Time) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateUpdatedAt() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *GroupUpsertBulk) SetDeletedAt(v time.Time) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateDeletedAt() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *GroupUpsertBulk) ClearDeletedAt() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *GroupUpsertBulk) SetName(v string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateName() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateName() + }) +} + +// SetMaxStorage sets the "max_storage" field. +func (u *GroupUpsertBulk) SetMaxStorage(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMaxStorage(v) + }) +} + +// AddMaxStorage adds v to the "max_storage" field. +func (u *GroupUpsertBulk) AddMaxStorage(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddMaxStorage(v) + }) +} + +// UpdateMaxStorage sets the "max_storage" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMaxStorage() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMaxStorage() + }) +} + +// ClearMaxStorage clears the value of the "max_storage" field. 
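The bulk path mirrors the single-row upsert; a sketch under the same assumptions (the names slice is illustrative):

// upsertGroups creates or updates several groups in one statement via GroupCreateBulk
// and GroupUpsertBulk; conflicts on the assumed unique "name" column are resolved
// with the newly proposed values.
func upsertGroups(ctx context.Context, client *ent.Client, perm *boolset.BooleanSet, names []string) error {
	builders := make([]*ent.GroupCreate, 0, len(names))
	for _, n := range names {
		builders = append(builders, client.Group.Create().SetName(n).SetPermissions(perm))
	}
	return client.Group.CreateBulk(builders...).
		OnConflictColumns(group.FieldName). // conflict target, assumed unique
		UpdateNewValues().
		Exec(ctx)
}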
+func (u *GroupUpsertBulk) ClearMaxStorage() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearMaxStorage() + }) +} + +// SetSpeedLimit sets the "speed_limit" field. +func (u *GroupUpsertBulk) SetSpeedLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSpeedLimit(v) + }) +} + +// AddSpeedLimit adds v to the "speed_limit" field. +func (u *GroupUpsertBulk) AddSpeedLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSpeedLimit(v) + }) +} + +// UpdateSpeedLimit sets the "speed_limit" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSpeedLimit() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSpeedLimit() + }) +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (u *GroupUpsertBulk) ClearSpeedLimit() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSpeedLimit() + }) +} + +// SetPermissions sets the "permissions" field. +func (u *GroupUpsertBulk) SetPermissions(v *boolset.BooleanSet) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetPermissions(v) + }) +} + +// UpdatePermissions sets the "permissions" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdatePermissions() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdatePermissions() + }) +} + +// SetSettings sets the "settings" field. +func (u *GroupUpsertBulk) SetSettings(v *types.GroupSetting) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSettings() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *GroupUpsertBulk) ClearSettings() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSettings() + }) +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (u *GroupUpsertBulk) SetStoragePolicyID(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetStoragePolicyID(v) + }) +} + +// UpdateStoragePolicyID sets the "storage_policy_id" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateStoragePolicyID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateStoragePolicyID() + }) +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (u *GroupUpsertBulk) ClearStoragePolicyID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearStoragePolicyID() + }) +} + +// Exec executes the query. +func (u *GroupUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the GroupCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for GroupCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *GroupUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/group_delete.go b/ent/group_delete.go new file mode 100644 index 00000000..96abd8f3 --- /dev/null +++ b/ent/group_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// GroupDelete is the builder for deleting a Group entity. +type GroupDelete struct { + config + hooks []Hook + mutation *GroupMutation +} + +// Where appends a list predicates to the GroupDelete builder. +func (gd *GroupDelete) Where(ps ...predicate.Group) *GroupDelete { + gd.mutation.Where(ps...) + return gd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (gd *GroupDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, gd.sqlExec, gd.mutation, gd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (gd *GroupDelete) ExecX(ctx context.Context) int { + n, err := gd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (gd *GroupDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(group.Table, sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt)) + if ps := gd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, gd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + gd.mutation.done = true + return affected, err +} + +// GroupDeleteOne is the builder for deleting a single Group entity. +type GroupDeleteOne struct { + gd *GroupDelete +} + +// Where appends a list predicates to the GroupDelete builder. +func (gdo *GroupDeleteOne) Where(ps ...predicate.Group) *GroupDeleteOne { + gdo.gd.mutation.Where(ps...) + return gdo +} + +// Exec executes the deletion query. +func (gdo *GroupDeleteOne) Exec(ctx context.Context) error { + n, err := gdo.gd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{group.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (gdo *GroupDeleteOne) ExecX(ctx context.Context) { + if err := gdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/group_query.go b/ent/group_query.go new file mode 100644 index 00000000..3646337c --- /dev/null +++ b/ent/group_query.go @@ -0,0 +1,681 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// GroupQuery is the builder for querying Group entities. +type GroupQuery struct { + config + ctx *QueryContext + order []group.OrderOption + inters []Interceptor + predicates []predicate.Group + withUsers *UserQuery + withStoragePolicies *StoragePolicyQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the GroupQuery builder. +func (gq *GroupQuery) Where(ps ...predicate.Group) *GroupQuery { + gq.predicates = append(gq.predicates, ps...) + return gq +} + +// Limit the number of records to be returned by this query. 
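The delete builders added in group_delete.go above follow the same pattern; a short sketch (client and ctx as before; client.Group.Delete and the generated predicate group.NameEQ are standard ent client/predicate helpers that are assumed here rather than shown in this diff):

// deleteGroupsByName removes every Group row matching the given name and
// reports how many vertices were deleted.
func deleteGroupsByName(ctx context.Context, client *ent.Client, name string) (int, error) {
	return client.Group.Delete().
		Where(group.NameEQ(name)). // predicate helper, assumed generated in ent/group
		Exec(ctx)
}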
+func (gq *GroupQuery) Limit(limit int) *GroupQuery { + gq.ctx.Limit = &limit + return gq +} + +// Offset to start from. +func (gq *GroupQuery) Offset(offset int) *GroupQuery { + gq.ctx.Offset = &offset + return gq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (gq *GroupQuery) Unique(unique bool) *GroupQuery { + gq.ctx.Unique = &unique + return gq +} + +// Order specifies how the records should be ordered. +func (gq *GroupQuery) Order(o ...group.OrderOption) *GroupQuery { + gq.order = append(gq.order, o...) + return gq +} + +// QueryUsers chains the current query on the "users" edge. +func (gq *GroupQuery) QueryUsers() *UserQuery { + query := (&UserClient{config: gq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := gq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := gq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsersTable, group.UsersColumn), + ) + fromU = sqlgraph.SetNeighbors(gq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryStoragePolicies chains the current query on the "storage_policies" edge. +func (gq *GroupQuery) QueryStoragePolicies() *StoragePolicyQuery { + query := (&StoragePolicyClient{config: gq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := gq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := gq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, selector), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, group.StoragePoliciesTable, group.StoragePoliciesColumn), + ) + fromU = sqlgraph.SetNeighbors(gq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Group entity from the query. +// Returns a *NotFoundError when no Group was found. +func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { + nodes, err := gq.Limit(1).All(setContextOp(ctx, gq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{group.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (gq *GroupQuery) FirstX(ctx context.Context) *Group { + node, err := gq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Group ID from the query. +// Returns a *NotFoundError when no Group ID was found. +func (gq *GroupQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = gq.Limit(1).IDs(setContextOp(ctx, gq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{group.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (gq *GroupQuery) FirstIDX(ctx context.Context) int { + id, err := gq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Group entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Group entity is found. 
+// Returns a *NotFoundError when no Group entities are found. +func (gq *GroupQuery) Only(ctx context.Context) (*Group, error) { + nodes, err := gq.Limit(2).All(setContextOp(ctx, gq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{group.Label} + default: + return nil, &NotSingularError{group.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (gq *GroupQuery) OnlyX(ctx context.Context) *Group { + node, err := gq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Group ID in the query. +// Returns a *NotSingularError when more than one Group ID is found. +// Returns a *NotFoundError when no entities are found. +func (gq *GroupQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = gq.Limit(2).IDs(setContextOp(ctx, gq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{group.Label} + default: + err = &NotSingularError{group.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (gq *GroupQuery) OnlyIDX(ctx context.Context) int { + id, err := gq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Groups. +func (gq *GroupQuery) All(ctx context.Context) ([]*Group, error) { + ctx = setContextOp(ctx, gq.ctx, "All") + if err := gq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Group, *GroupQuery]() + return withInterceptors[[]*Group](ctx, gq, qr, gq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (gq *GroupQuery) AllX(ctx context.Context) []*Group { + nodes, err := gq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Group IDs. +func (gq *GroupQuery) IDs(ctx context.Context) (ids []int, err error) { + if gq.ctx.Unique == nil && gq.path != nil { + gq.Unique(true) + } + ctx = setContextOp(ctx, gq.ctx, "IDs") + if err = gq.Select(group.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (gq *GroupQuery) IDsX(ctx context.Context) []int { + ids, err := gq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (gq *GroupQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, gq.ctx, "Count") + if err := gq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, gq, querierCount[*GroupQuery](), gq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (gq *GroupQuery) CountX(ctx context.Context) int { + count, err := gq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (gq *GroupQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, gq.ctx, "Exist") + switch _, err := gq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. 
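A sketch of the query executors above (client and ctx as before; group.NameEQ is again an assumed generated predicate):

// findGroup returns exactly one Group by name; Only reports NotFound when no row
// matches and NotSingular when more than one does.
func findGroup(ctx context.Context, client *ent.Client, name string) (*ent.Group, error) {
	return client.Group.Query().
		Where(group.NameEQ(name)).
		Only(ctx)
}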
+func (gq *GroupQuery) ExistX(ctx context.Context) bool { + exist, err := gq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the GroupQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (gq *GroupQuery) Clone() *GroupQuery { + if gq == nil { + return nil + } + return &GroupQuery{ + config: gq.config, + ctx: gq.ctx.Clone(), + order: append([]group.OrderOption{}, gq.order...), + inters: append([]Interceptor{}, gq.inters...), + predicates: append([]predicate.Group{}, gq.predicates...), + withUsers: gq.withUsers.Clone(), + withStoragePolicies: gq.withStoragePolicies.Clone(), + // clone intermediate query. + sql: gq.sql.Clone(), + path: gq.path, + } +} + +// WithUsers tells the query-builder to eager-load the nodes that are connected to +// the "users" edge. The optional arguments are used to configure the query builder of the edge. +func (gq *GroupQuery) WithUsers(opts ...func(*UserQuery)) *GroupQuery { + query := (&UserClient{config: gq.config}).Query() + for _, opt := range opts { + opt(query) + } + gq.withUsers = query + return gq +} + +// WithStoragePolicies tells the query-builder to eager-load the nodes that are connected to +// the "storage_policies" edge. The optional arguments are used to configure the query builder of the edge. +func (gq *GroupQuery) WithStoragePolicies(opts ...func(*StoragePolicyQuery)) *GroupQuery { + query := (&StoragePolicyClient{config: gq.config}).Query() + for _, opt := range opts { + opt(query) + } + gq.withStoragePolicies = query + return gq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Group.Query(). +// GroupBy(group.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { + gq.ctx.Fields = append([]string{field}, fields...) + grbuild := &GroupGroupBy{build: gq} + grbuild.flds = &gq.ctx.Fields + grbuild.label = group.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Group.Query(). +// Select(group.FieldCreatedAt). +// Scan(ctx, &v) +func (gq *GroupQuery) Select(fields ...string) *GroupSelect { + gq.ctx.Fields = append(gq.ctx.Fields, fields...) + sbuild := &GroupSelect{GroupQuery: gq} + sbuild.label = group.Label + sbuild.flds, sbuild.scan = &gq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a GroupSelect configured with the given aggregations. +func (gq *GroupQuery) Aggregate(fns ...AggregateFunc) *GroupSelect { + return gq.Select().Aggregate(fns...) 
+} + +func (gq *GroupQuery) prepareQuery(ctx context.Context) error { + for _, inter := range gq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, gq); err != nil { + return err + } + } + } + for _, f := range gq.ctx.Fields { + if !group.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if gq.path != nil { + prev, err := gq.path(ctx) + if err != nil { + return err + } + gq.sql = prev + } + return nil +} + +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { + var ( + nodes = []*Group{} + _spec = gq.querySpec() + loadedTypes = [2]bool{ + gq.withUsers != nil, + gq.withStoragePolicies != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Group).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Group{config: gq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := gq.withUsers; query != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { + return nil, err + } + } + if query := gq.withStoragePolicies; query != nil { + if err := gq.loadStoragePolicies(ctx, query, nodes, nil, + func(n *Group, e *StoragePolicy) { n.Edges.StoragePolicies = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(user.FieldGroupUsers) + } + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(group.UsersColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.GroupUsers + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "group_users" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (gq *GroupQuery) loadStoragePolicies(ctx context.Context, query *StoragePolicyQuery, nodes []*Group, init func(*Group), assign func(*Group, *StoragePolicy)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Group) + for i := range nodes { + fk := nodes[i].StoragePolicyID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(storagepolicy.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "storage_policy_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (gq 
*GroupQuery) sqlCount(ctx context.Context) (int, error) { + _spec := gq.querySpec() + _spec.Node.Columns = gq.ctx.Fields + if len(gq.ctx.Fields) > 0 { + _spec.Unique = gq.ctx.Unique != nil && *gq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, gq.driver, _spec) +} + +func (gq *GroupQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(group.Table, group.Columns, sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt)) + _spec.From = gq.sql + if unique := gq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if gq.path != nil { + _spec.Unique = true + } + if fields := gq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) + for i := range fields { + if fields[i] != group.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if gq.withStoragePolicies != nil { + _spec.Node.AddColumnOnce(group.FieldStoragePolicyID) + } + } + if ps := gq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := gq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := gq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := gq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (gq *GroupQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(gq.driver.Dialect()) + t1 := builder.Table(group.Table) + columns := gq.ctx.Fields + if len(columns) == 0 { + columns = group.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if gq.sql != nil { + selector = gq.sql + selector.Select(selector.Columns(columns...)...) + } + if gq.ctx.Unique != nil && *gq.ctx.Unique { + selector.Distinct() + } + for _, p := range gq.predicates { + p(selector) + } + for _, p := range gq.order { + p(selector) + } + if offset := gq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := gq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// GroupGroupBy is the group-by builder for Group entities. +type GroupGroupBy struct { + selector + build *GroupQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (ggb *GroupGroupBy) Aggregate(fns ...AggregateFunc) *GroupGroupBy { + ggb.fns = append(ggb.fns, fns...) + return ggb +} + +// Scan applies the selector query and scans the result into the given value. +func (ggb *GroupGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ggb.build.ctx, "GroupBy") + if err := ggb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*GroupQuery, *GroupGroupBy](ctx, ggb.build, ggb, ggb.build.inters, v) +} + +func (ggb *GroupGroupBy) sqlScan(ctx context.Context, root *GroupQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(ggb.fns)) + for _, fn := range ggb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*ggb.flds)+len(ggb.fns)) + for _, f := range *ggb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) 
+ } + selector.GroupBy(selector.Columns(*ggb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ggb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// GroupSelect is the builder for selecting fields of Group entities. +type GroupSelect struct { + *GroupQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (gs *GroupSelect) Aggregate(fns ...AggregateFunc) *GroupSelect { + gs.fns = append(gs.fns, fns...) + return gs +} + +// Scan applies the selector query and scans the result into the given value. +func (gs *GroupSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, gs.ctx, "Select") + if err := gs.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*GroupQuery, *GroupSelect](ctx, gs.GroupQuery, gs, gs.inters, v) +} + +func (gs *GroupSelect) sqlScan(ctx context.Context, root *GroupQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(gs.fns)) + for _, fn := range gs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*gs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := gs.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/group_update.go b/ent/group_update.go new file mode 100644 index 00000000..ae734535 --- /dev/null +++ b/ent/group_update.go @@ -0,0 +1,822 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// GroupUpdate is the builder for updating Group entities. +type GroupUpdate struct { + config + hooks []Hook + mutation *GroupMutation +} + +// Where appends a list predicates to the GroupUpdate builder. +func (gu *GroupUpdate) Where(ps ...predicate.Group) *GroupUpdate { + gu.mutation.Where(ps...) + return gu +} + +// SetUpdatedAt sets the "updated_at" field. +func (gu *GroupUpdate) SetUpdatedAt(t time.Time) *GroupUpdate { + gu.mutation.SetUpdatedAt(t) + return gu +} + +// SetDeletedAt sets the "deleted_at" field. +func (gu *GroupUpdate) SetDeletedAt(t time.Time) *GroupUpdate { + gu.mutation.SetDeletedAt(t) + return gu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (gu *GroupUpdate) SetNillableDeletedAt(t *time.Time) *GroupUpdate { + if t != nil { + gu.SetDeletedAt(*t) + } + return gu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (gu *GroupUpdate) ClearDeletedAt() *GroupUpdate { + gu.mutation.ClearDeletedAt() + return gu +} + +// SetName sets the "name" field. +func (gu *GroupUpdate) SetName(s string) *GroupUpdate { + gu.mutation.SetName(s) + return gu +} + +// SetNillableName sets the "name" field if the given value is not nil. 
+func (gu *GroupUpdate) SetNillableName(s *string) *GroupUpdate { + if s != nil { + gu.SetName(*s) + } + return gu +} + +// SetMaxStorage sets the "max_storage" field. +func (gu *GroupUpdate) SetMaxStorage(i int64) *GroupUpdate { + gu.mutation.ResetMaxStorage() + gu.mutation.SetMaxStorage(i) + return gu +} + +// SetNillableMaxStorage sets the "max_storage" field if the given value is not nil. +func (gu *GroupUpdate) SetNillableMaxStorage(i *int64) *GroupUpdate { + if i != nil { + gu.SetMaxStorage(*i) + } + return gu +} + +// AddMaxStorage adds i to the "max_storage" field. +func (gu *GroupUpdate) AddMaxStorage(i int64) *GroupUpdate { + gu.mutation.AddMaxStorage(i) + return gu +} + +// ClearMaxStorage clears the value of the "max_storage" field. +func (gu *GroupUpdate) ClearMaxStorage() *GroupUpdate { + gu.mutation.ClearMaxStorage() + return gu +} + +// SetSpeedLimit sets the "speed_limit" field. +func (gu *GroupUpdate) SetSpeedLimit(i int) *GroupUpdate { + gu.mutation.ResetSpeedLimit() + gu.mutation.SetSpeedLimit(i) + return gu +} + +// SetNillableSpeedLimit sets the "speed_limit" field if the given value is not nil. +func (gu *GroupUpdate) SetNillableSpeedLimit(i *int) *GroupUpdate { + if i != nil { + gu.SetSpeedLimit(*i) + } + return gu +} + +// AddSpeedLimit adds i to the "speed_limit" field. +func (gu *GroupUpdate) AddSpeedLimit(i int) *GroupUpdate { + gu.mutation.AddSpeedLimit(i) + return gu +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (gu *GroupUpdate) ClearSpeedLimit() *GroupUpdate { + gu.mutation.ClearSpeedLimit() + return gu +} + +// SetPermissions sets the "permissions" field. +func (gu *GroupUpdate) SetPermissions(bs *boolset.BooleanSet) *GroupUpdate { + gu.mutation.SetPermissions(bs) + return gu +} + +// SetSettings sets the "settings" field. +func (gu *GroupUpdate) SetSettings(ts *types.GroupSetting) *GroupUpdate { + gu.mutation.SetSettings(ts) + return gu +} + +// ClearSettings clears the value of the "settings" field. +func (gu *GroupUpdate) ClearSettings() *GroupUpdate { + gu.mutation.ClearSettings() + return gu +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (gu *GroupUpdate) SetStoragePolicyID(i int) *GroupUpdate { + gu.mutation.SetStoragePolicyID(i) + return gu +} + +// SetNillableStoragePolicyID sets the "storage_policy_id" field if the given value is not nil. +func (gu *GroupUpdate) SetNillableStoragePolicyID(i *int) *GroupUpdate { + if i != nil { + gu.SetStoragePolicyID(*i) + } + return gu +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (gu *GroupUpdate) ClearStoragePolicyID() *GroupUpdate { + gu.mutation.ClearStoragePolicyID() + return gu +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (gu *GroupUpdate) AddUserIDs(ids ...int) *GroupUpdate { + gu.mutation.AddUserIDs(ids...) + return gu +} + +// AddUsers adds the "users" edges to the User entity. +func (gu *GroupUpdate) AddUsers(u ...*User) *GroupUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gu.AddUserIDs(ids...) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (gu *GroupUpdate) SetStoragePoliciesID(id int) *GroupUpdate { + gu.mutation.SetStoragePoliciesID(id) + return gu +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. 
+func (gu *GroupUpdate) SetNillableStoragePoliciesID(id *int) *GroupUpdate { + if id != nil { + gu = gu.SetStoragePoliciesID(*id) + } + return gu +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. +func (gu *GroupUpdate) SetStoragePolicies(s *StoragePolicy) *GroupUpdate { + return gu.SetStoragePoliciesID(s.ID) +} + +// Mutation returns the GroupMutation object of the builder. +func (gu *GroupUpdate) Mutation() *GroupMutation { + return gu.mutation +} + +// ClearUsers clears all "users" edges to the User entity. +func (gu *GroupUpdate) ClearUsers() *GroupUpdate { + gu.mutation.ClearUsers() + return gu +} + +// RemoveUserIDs removes the "users" edge to User entities by IDs. +func (gu *GroupUpdate) RemoveUserIDs(ids ...int) *GroupUpdate { + gu.mutation.RemoveUserIDs(ids...) + return gu +} + +// RemoveUsers removes "users" edges to User entities. +func (gu *GroupUpdate) RemoveUsers(u ...*User) *GroupUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gu.RemoveUserIDs(ids...) +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. +func (gu *GroupUpdate) ClearStoragePolicies() *GroupUpdate { + gu.mutation.ClearStoragePolicies() + return gu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (gu *GroupUpdate) Save(ctx context.Context) (int, error) { + if err := gu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, gu.sqlSave, gu.mutation, gu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (gu *GroupUpdate) SaveX(ctx context.Context) int { + affected, err := gu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (gu *GroupUpdate) Exec(ctx context.Context) error { + _, err := gu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (gu *GroupUpdate) ExecX(ctx context.Context) { + if err := gu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
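And a sketch of the bulk update builder above, under the same assumptions (client.Group.Update and the predicate are assumed helpers; the speed-limit value is illustrative):

// throttleGroup sets a new speed_limit on every Group with the given name and
// returns the number of affected rows; updated_at is refreshed by defaults().
func throttleGroup(ctx context.Context, client *ent.Client, name string, limit int) (int, error) {
	return client.Group.Update().
		Where(group.NameEQ(name)). // predicate helper, assumed generated in ent/group
		SetSpeedLimit(limit).
		Save(ctx)
}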
+func (gu *GroupUpdate) defaults() error { + if _, ok := gu.mutation.UpdatedAt(); !ok { + if group.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized group.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := group.UpdateDefaultUpdatedAt() + gu.mutation.SetUpdatedAt(v) + } + return nil +} + +func (gu *GroupUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(group.Table, group.Columns, sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt)) + if ps := gu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := gu.mutation.UpdatedAt(); ok { + _spec.SetField(group.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := gu.mutation.DeletedAt(); ok { + _spec.SetField(group.FieldDeletedAt, field.TypeTime, value) + } + if gu.mutation.DeletedAtCleared() { + _spec.ClearField(group.FieldDeletedAt, field.TypeTime) + } + if value, ok := gu.mutation.Name(); ok { + _spec.SetField(group.FieldName, field.TypeString, value) + } + if value, ok := gu.mutation.MaxStorage(); ok { + _spec.SetField(group.FieldMaxStorage, field.TypeInt64, value) + } + if value, ok := gu.mutation.AddedMaxStorage(); ok { + _spec.AddField(group.FieldMaxStorage, field.TypeInt64, value) + } + if gu.mutation.MaxStorageCleared() { + _spec.ClearField(group.FieldMaxStorage, field.TypeInt64) + } + if value, ok := gu.mutation.SpeedLimit(); ok { + _spec.SetField(group.FieldSpeedLimit, field.TypeInt, value) + } + if value, ok := gu.mutation.AddedSpeedLimit(); ok { + _spec.AddField(group.FieldSpeedLimit, field.TypeInt, value) + } + if gu.mutation.SpeedLimitCleared() { + _spec.ClearField(group.FieldSpeedLimit, field.TypeInt) + } + if value, ok := gu.mutation.Permissions(); ok { + _spec.SetField(group.FieldPermissions, field.TypeBytes, value) + } + if value, ok := gu.mutation.Settings(); ok { + _spec.SetField(group.FieldSettings, field.TypeJSON, value) + } + if gu.mutation.SettingsCleared() { + _spec.ClearField(group.FieldSettings, field.TypeJSON) + } + if gu.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := gu.mutation.RemovedUsersIDs(); len(nodes) > 0 && !gu.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := gu.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if gu.mutation.StoragePoliciesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: group.StoragePoliciesTable, + Columns: 
[]string{group.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := gu.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: group.StoragePoliciesTable, + Columns: []string{group.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, gu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{group.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + gu.mutation.done = true + return n, nil +} + +// GroupUpdateOne is the builder for updating a single Group entity. +type GroupUpdateOne struct { + config + fields []string + hooks []Hook + mutation *GroupMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (guo *GroupUpdateOne) SetUpdatedAt(t time.Time) *GroupUpdateOne { + guo.mutation.SetUpdatedAt(t) + return guo +} + +// SetDeletedAt sets the "deleted_at" field. +func (guo *GroupUpdateOne) SetDeletedAt(t time.Time) *GroupUpdateOne { + guo.mutation.SetDeletedAt(t) + return guo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableDeletedAt(t *time.Time) *GroupUpdateOne { + if t != nil { + guo.SetDeletedAt(*t) + } + return guo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (guo *GroupUpdateOne) ClearDeletedAt() *GroupUpdateOne { + guo.mutation.ClearDeletedAt() + return guo +} + +// SetName sets the "name" field. +func (guo *GroupUpdateOne) SetName(s string) *GroupUpdateOne { + guo.mutation.SetName(s) + return guo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableName(s *string) *GroupUpdateOne { + if s != nil { + guo.SetName(*s) + } + return guo +} + +// SetMaxStorage sets the "max_storage" field. +func (guo *GroupUpdateOne) SetMaxStorage(i int64) *GroupUpdateOne { + guo.mutation.ResetMaxStorage() + guo.mutation.SetMaxStorage(i) + return guo +} + +// SetNillableMaxStorage sets the "max_storage" field if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableMaxStorage(i *int64) *GroupUpdateOne { + if i != nil { + guo.SetMaxStorage(*i) + } + return guo +} + +// AddMaxStorage adds i to the "max_storage" field. +func (guo *GroupUpdateOne) AddMaxStorage(i int64) *GroupUpdateOne { + guo.mutation.AddMaxStorage(i) + return guo +} + +// ClearMaxStorage clears the value of the "max_storage" field. +func (guo *GroupUpdateOne) ClearMaxStorage() *GroupUpdateOne { + guo.mutation.ClearMaxStorage() + return guo +} + +// SetSpeedLimit sets the "speed_limit" field. +func (guo *GroupUpdateOne) SetSpeedLimit(i int) *GroupUpdateOne { + guo.mutation.ResetSpeedLimit() + guo.mutation.SetSpeedLimit(i) + return guo +} + +// SetNillableSpeedLimit sets the "speed_limit" field if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableSpeedLimit(i *int) *GroupUpdateOne { + if i != nil { + guo.SetSpeedLimit(*i) + } + return guo +} + +// AddSpeedLimit adds i to the "speed_limit" field. 
+func (guo *GroupUpdateOne) AddSpeedLimit(i int) *GroupUpdateOne { + guo.mutation.AddSpeedLimit(i) + return guo +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (guo *GroupUpdateOne) ClearSpeedLimit() *GroupUpdateOne { + guo.mutation.ClearSpeedLimit() + return guo +} + +// SetPermissions sets the "permissions" field. +func (guo *GroupUpdateOne) SetPermissions(bs *boolset.BooleanSet) *GroupUpdateOne { + guo.mutation.SetPermissions(bs) + return guo +} + +// SetSettings sets the "settings" field. +func (guo *GroupUpdateOne) SetSettings(ts *types.GroupSetting) *GroupUpdateOne { + guo.mutation.SetSettings(ts) + return guo +} + +// ClearSettings clears the value of the "settings" field. +func (guo *GroupUpdateOne) ClearSettings() *GroupUpdateOne { + guo.mutation.ClearSettings() + return guo +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (guo *GroupUpdateOne) SetStoragePolicyID(i int) *GroupUpdateOne { + guo.mutation.SetStoragePolicyID(i) + return guo +} + +// SetNillableStoragePolicyID sets the "storage_policy_id" field if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableStoragePolicyID(i *int) *GroupUpdateOne { + if i != nil { + guo.SetStoragePolicyID(*i) + } + return guo +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (guo *GroupUpdateOne) ClearStoragePolicyID() *GroupUpdateOne { + guo.mutation.ClearStoragePolicyID() + return guo +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (guo *GroupUpdateOne) AddUserIDs(ids ...int) *GroupUpdateOne { + guo.mutation.AddUserIDs(ids...) + return guo +} + +// AddUsers adds the "users" edges to the User entity. +func (guo *GroupUpdateOne) AddUsers(u ...*User) *GroupUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return guo.AddUserIDs(ids...) +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID. +func (guo *GroupUpdateOne) SetStoragePoliciesID(id int) *GroupUpdateOne { + guo.mutation.SetStoragePoliciesID(id) + return guo +} + +// SetNillableStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by ID if the given value is not nil. +func (guo *GroupUpdateOne) SetNillableStoragePoliciesID(id *int) *GroupUpdateOne { + if id != nil { + guo = guo.SetStoragePoliciesID(*id) + } + return guo +} + +// SetStoragePolicies sets the "storage_policies" edge to the StoragePolicy entity. +func (guo *GroupUpdateOne) SetStoragePolicies(s *StoragePolicy) *GroupUpdateOne { + return guo.SetStoragePoliciesID(s.ID) +} + +// Mutation returns the GroupMutation object of the builder. +func (guo *GroupUpdateOne) Mutation() *GroupMutation { + return guo.mutation +} + +// ClearUsers clears all "users" edges to the User entity. +func (guo *GroupUpdateOne) ClearUsers() *GroupUpdateOne { + guo.mutation.ClearUsers() + return guo +} + +// RemoveUserIDs removes the "users" edge to User entities by IDs. +func (guo *GroupUpdateOne) RemoveUserIDs(ids ...int) *GroupUpdateOne { + guo.mutation.RemoveUserIDs(ids...) + return guo +} + +// RemoveUsers removes "users" edges to User entities. +func (guo *GroupUpdateOne) RemoveUsers(u ...*User) *GroupUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return guo.RemoveUserIDs(ids...) +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. 
+func (guo *GroupUpdateOne) ClearStoragePolicies() *GroupUpdateOne { + guo.mutation.ClearStoragePolicies() + return guo +} + +// Where appends a list predicates to the GroupUpdate builder. +func (guo *GroupUpdateOne) Where(ps ...predicate.Group) *GroupUpdateOne { + guo.mutation.Where(ps...) + return guo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (guo *GroupUpdateOne) Select(field string, fields ...string) *GroupUpdateOne { + guo.fields = append([]string{field}, fields...) + return guo +} + +// Save executes the query and returns the updated Group entity. +func (guo *GroupUpdateOne) Save(ctx context.Context) (*Group, error) { + if err := guo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, guo.sqlSave, guo.mutation, guo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (guo *GroupUpdateOne) SaveX(ctx context.Context) *Group { + node, err := guo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (guo *GroupUpdateOne) Exec(ctx context.Context) error { + _, err := guo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (guo *GroupUpdateOne) ExecX(ctx context.Context) { + if err := guo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (guo *GroupUpdateOne) defaults() error { + if _, ok := guo.mutation.UpdatedAt(); !ok { + if group.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized group.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := group.UpdateDefaultUpdatedAt() + guo.mutation.SetUpdatedAt(v) + } + return nil +} + +func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) { + _spec := sqlgraph.NewUpdateSpec(group.Table, group.Columns, sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt)) + id, ok := guo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Group.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := guo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, group.FieldID) + for _, f := range fields { + if !group.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != group.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := guo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := guo.mutation.UpdatedAt(); ok { + _spec.SetField(group.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := guo.mutation.DeletedAt(); ok { + _spec.SetField(group.FieldDeletedAt, field.TypeTime, value) + } + if guo.mutation.DeletedAtCleared() { + _spec.ClearField(group.FieldDeletedAt, field.TypeTime) + } + if value, ok := guo.mutation.Name(); ok { + _spec.SetField(group.FieldName, field.TypeString, value) + } + if value, ok := guo.mutation.MaxStorage(); ok { + _spec.SetField(group.FieldMaxStorage, field.TypeInt64, value) + } + if value, ok := guo.mutation.AddedMaxStorage(); ok { + _spec.AddField(group.FieldMaxStorage, field.TypeInt64, value) + } + if guo.mutation.MaxStorageCleared() { + _spec.ClearField(group.FieldMaxStorage, field.TypeInt64) + } + if value, ok := 
guo.mutation.SpeedLimit(); ok { + _spec.SetField(group.FieldSpeedLimit, field.TypeInt, value) + } + if value, ok := guo.mutation.AddedSpeedLimit(); ok { + _spec.AddField(group.FieldSpeedLimit, field.TypeInt, value) + } + if guo.mutation.SpeedLimitCleared() { + _spec.ClearField(group.FieldSpeedLimit, field.TypeInt) + } + if value, ok := guo.mutation.Permissions(); ok { + _spec.SetField(group.FieldPermissions, field.TypeBytes, value) + } + if value, ok := guo.mutation.Settings(); ok { + _spec.SetField(group.FieldSettings, field.TypeJSON, value) + } + if guo.mutation.SettingsCleared() { + _spec.ClearField(group.FieldSettings, field.TypeJSON) + } + if guo.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := guo.mutation.RemovedUsersIDs(); len(nodes) > 0 && !guo.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := guo.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsersTable, + Columns: []string{group.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if guo.mutation.StoragePoliciesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: group.StoragePoliciesTable, + Columns: []string{group.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := guo.mutation.StoragePoliciesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: group.StoragePoliciesTable, + Columns: []string{group.StoragePoliciesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Group{config: guo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, guo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{group.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + guo.mutation.done = true + return _node, nil +} diff --git a/ent/hook/hook.go b/ent/hook/hook.go new file mode 100644 index 00000000..29b3f1fc --- /dev/null +++ b/ent/hook/hook.go @@ -0,0 +1,343 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package hook + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/ent" +) + +// The DavAccountFunc type is an adapter to allow the use of ordinary +// function as DavAccount mutator. +type DavAccountFunc func(context.Context, *ent.DavAccountMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f DavAccountFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.DavAccountMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DavAccountMutation", m) +} + +// The DirectLinkFunc type is an adapter to allow the use of ordinary +// function as DirectLink mutator. +type DirectLinkFunc func(context.Context, *ent.DirectLinkMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f DirectLinkFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.DirectLinkMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DirectLinkMutation", m) +} + +// The EntityFunc type is an adapter to allow the use of ordinary +// function as Entity mutator. +type EntityFunc func(context.Context, *ent.EntityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f EntityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.EntityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EntityMutation", m) +} + +// The FileFunc type is an adapter to allow the use of ordinary +// function as File mutator. +type FileFunc func(context.Context, *ent.FileMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f FileFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.FileMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.FileMutation", m) +} + +// The GroupFunc type is an adapter to allow the use of ordinary +// function as Group mutator. +type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.GroupMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m) +} + +// The MetadataFunc type is an adapter to allow the use of ordinary +// function as Metadata mutator. +type MetadataFunc func(context.Context, *ent.MetadataMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f MetadataFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.MetadataMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetadataMutation", m) +} + +// The NodeFunc type is an adapter to allow the use of ordinary +// function as Node mutator. +type NodeFunc func(context.Context, *ent.NodeMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f NodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.NodeMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.NodeMutation", m) +} + +// The PasskeyFunc type is an adapter to allow the use of ordinary +// function as Passkey mutator. +type PasskeyFunc func(context.Context, *ent.PasskeyMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f PasskeyFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PasskeyMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PasskeyMutation", m) +} + +// The SettingFunc type is an adapter to allow the use of ordinary +// function as Setting mutator. +type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SettingMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m) +} + +// The ShareFunc type is an adapter to allow the use of ordinary +// function as Share mutator. +type ShareFunc func(context.Context, *ent.ShareMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ShareFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ShareMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ShareMutation", m) +} + +// The StoragePolicyFunc type is an adapter to allow the use of ordinary +// function as StoragePolicy mutator. +type StoragePolicyFunc func(context.Context, *ent.StoragePolicyMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f StoragePolicyFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.StoragePolicyMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.StoragePolicyMutation", m) +} + +// The TaskFunc type is an adapter to allow the use of ordinary +// function as Task mutator. +type TaskFunc func(context.Context, *ent.TaskMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f TaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.TaskMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TaskMutation", m) +} + +// The UserFunc type is an adapter to allow the use of ordinary +// function as User mutator. +type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f UserFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.UserMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserMutation", m) +} + +// Condition is a hook condition function. +type Condition func(context.Context, ent.Mutation) bool + +// And groups conditions with the AND operator. +func And(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + if !first(ctx, m) || !second(ctx, m) { + return false + } + for _, cond := range rest { + if !cond(ctx, m) { + return false + } + } + return true + } +} + +// Or groups conditions with the OR operator. +func Or(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + if first(ctx, m) || second(ctx, m) { + return true + } + for _, cond := range rest { + if cond(ctx, m) { + return true + } + } + return false + } +} + +// Not negates a given condition. +func Not(cond Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + return !cond(ctx, m) + } +} + +// HasOp is a condition testing mutation operation. 
+func HasOp(op ent.Op) Condition { + return func(_ context.Context, m ent.Mutation) bool { + return m.Op().Is(op) + } +} + +// HasAddedFields is a condition validating `.AddedField` on fields. +func HasAddedFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if _, exists := m.AddedField(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.AddedField(field); !exists { + return false + } + } + return true + } +} + +// HasClearedFields is a condition validating `.FieldCleared` on fields. +func HasClearedFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if exists := m.FieldCleared(field); !exists { + return false + } + for _, field := range fields { + if exists := m.FieldCleared(field); !exists { + return false + } + } + return true + } +} + +// HasFields is a condition validating `.Field` on fields. +func HasFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if _, exists := m.Field(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.Field(field); !exists { + return false + } + } + return true + } +} + +// If executes the given hook under condition. +// +// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) +func If(hk ent.Hook, cond Condition) ent.Hook { + return func(next ent.Mutator) ent.Mutator { + return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if cond(ctx, m) { + return hk(next).Mutate(ctx, m) + } + return next.Mutate(ctx, m) + }) + } +} + +// On executes the given hook only for the given operation. +// +// hook.On(Log, ent.Delete|ent.Create) +func On(hk ent.Hook, op ent.Op) ent.Hook { + return If(hk, HasOp(op)) +} + +// Unless skips the given hook only for the given operation. +// +// hook.Unless(Log, ent.Update|ent.UpdateOne) +func Unless(hk ent.Hook, op ent.Op) ent.Hook { + return If(hk, Not(HasOp(op))) +} + +// FixedError is a hook returning a fixed error. +func FixedError(err error) ent.Hook { + return func(ent.Mutator) ent.Mutator { + return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { + return nil, err + }) + } +} + +// Reject returns a hook that rejects all operations that match op. +// +// func (T) Hooks() []ent.Hook { +// return []ent.Hook{ +// Reject(ent.Delete|ent.Update), +// } +// } +func Reject(op ent.Op) ent.Hook { + hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) + return On(hk, op) +} + +// Chain acts as a list of hooks and is effectively immutable. +// Once created, it will always hold the same set of hooks in the same order. +type Chain struct { + hooks []ent.Hook +} + +// NewChain creates a new chain of hooks. +func NewChain(hooks ...ent.Hook) Chain { + return Chain{append([]ent.Hook(nil), hooks...)} +} + +// Hook chains the list of hooks and returns the final hook. +func (c Chain) Hook() ent.Hook { + return func(mutator ent.Mutator) ent.Mutator { + for i := len(c.hooks) - 1; i >= 0; i-- { + mutator = c.hooks[i](mutator) + } + return mutator + } +} + +// Append extends a chain, adding the specified hook +// as the last ones in the mutation flow. +func (c Chain) Append(hooks ...ent.Hook) Chain { + newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) + newHooks = append(newHooks, c.hooks...) + newHooks = append(newHooks, hooks...) 
+ return Chain{newHooks} +} + +// Extend extends a chain, adding the specified chain +// as the last ones in the mutation flow. +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.hooks...) +} diff --git a/ent/intercept/intercept.go b/ent/intercept/intercept.go new file mode 100644 index 00000000..da350367 --- /dev/null +++ b/ent/intercept/intercept.go @@ -0,0 +1,509 @@ +// Code generated by ent, DO NOT EDIT. + +package intercept + +import ( + "context" + "fmt" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// The Query interface represents an operation that queries a graph. +// By using this interface, users can write generic code that manipulates +// query builders of different types. +type Query interface { + // Type returns the string representation of the query type. + Type() string + // Limit the number of records to be returned by this query. + Limit(int) + // Offset to start from. + Offset(int) + // Unique configures the query builder to filter duplicate records. + Unique(bool) + // Order specifies how the records should be ordered. + Order(...func(*sql.Selector)) + // WhereP appends storage-level predicates to the query builder. Using this method, users + // can use type-assertion to append predicates that do not depend on any generated package. + WhereP(...func(*sql.Selector)) +} + +// The Func type is an adapter that allows ordinary functions to be used as interceptors. +// Unlike traversal functions, interceptors are skipped during graph traversals. Note that the +// implementation of Func is different from the one defined in entgo.io/ent.InterceptFunc. +type Func func(context.Context, Query) error + +// Intercept calls f(ctx, q) and then applied the next Querier. +func (f Func) Intercept(next ent.Querier) ent.Querier { + return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) { + query, err := NewQuery(q) + if err != nil { + return nil, err + } + if err := f(ctx, query); err != nil { + return nil, err + } + return next.Query(ctx, q) + }) +} + +// The TraverseFunc type is an adapter to allow the use of ordinary function as Traverser. +// If f is a function with the appropriate signature, TraverseFunc(f) is a Traverser that calls f. +type TraverseFunc func(context.Context, Query) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseFunc) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseFunc) Traverse(ctx context.Context, q ent.Query) error { + query, err := NewQuery(q) + if err != nil { + return err + } + return f(ctx, query) +} + +// The DavAccountFunc type is an adapter to allow the use of ordinary function as a Querier. 
+type DavAccountFunc func(context.Context, *ent.DavAccountQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f DavAccountFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.DavAccountQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.DavAccountQuery", q) +} + +// The TraverseDavAccount type is an adapter to allow the use of ordinary function as Traverser. +type TraverseDavAccount func(context.Context, *ent.DavAccountQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseDavAccount) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseDavAccount) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.DavAccountQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.DavAccountQuery", q) +} + +// The DirectLinkFunc type is an adapter to allow the use of ordinary function as a Querier. +type DirectLinkFunc func(context.Context, *ent.DirectLinkQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f DirectLinkFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.DirectLinkQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.DirectLinkQuery", q) +} + +// The TraverseDirectLink type is an adapter to allow the use of ordinary function as Traverser. +type TraverseDirectLink func(context.Context, *ent.DirectLinkQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseDirectLink) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseDirectLink) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.DirectLinkQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.DirectLinkQuery", q) +} + +// The EntityFunc type is an adapter to allow the use of ordinary function as a Querier. +type EntityFunc func(context.Context, *ent.EntityQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f EntityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.EntityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.EntityQuery", q) +} + +// The TraverseEntity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseEntity func(context.Context, *ent.EntityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseEntity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseEntity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.EntityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.EntityQuery", q) +} + +// The FileFunc type is an adapter to allow the use of ordinary function as a Querier. +type FileFunc func(context.Context, *ent.FileQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f FileFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.FileQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. 
expect *ent.FileQuery", q) +} + +// The TraverseFile type is an adapter to allow the use of ordinary function as Traverser. +type TraverseFile func(context.Context, *ent.FileQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseFile) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseFile) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.FileQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.FileQuery", q) +} + +// The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. +type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f GroupFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.GroupQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) +} + +// The TraverseGroup type is an adapter to allow the use of ordinary function as Traverser. +type TraverseGroup func(context.Context, *ent.GroupQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseGroup) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.GroupQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) +} + +// The MetadataFunc type is an adapter to allow the use of ordinary function as a Querier. +type MetadataFunc func(context.Context, *ent.MetadataQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f MetadataFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.MetadataQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.MetadataQuery", q) +} + +// The TraverseMetadata type is an adapter to allow the use of ordinary function as Traverser. +type TraverseMetadata func(context.Context, *ent.MetadataQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseMetadata) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseMetadata) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.MetadataQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.MetadataQuery", q) +} + +// The NodeFunc type is an adapter to allow the use of ordinary function as a Querier. +type NodeFunc func(context.Context, *ent.NodeQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f NodeFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.NodeQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.NodeQuery", q) +} + +// The TraverseNode type is an adapter to allow the use of ordinary function as Traverser. +type TraverseNode func(context.Context, *ent.NodeQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseNode) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). 
+func (f TraverseNode) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.NodeQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.NodeQuery", q) +} + +// The PasskeyFunc type is an adapter to allow the use of ordinary function as a Querier. +type PasskeyFunc func(context.Context, *ent.PasskeyQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PasskeyFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PasskeyQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.PasskeyQuery", q) +} + +// The TraversePasskey type is an adapter to allow the use of ordinary function as Traverser. +type TraversePasskey func(context.Context, *ent.PasskeyQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePasskey) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePasskey) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PasskeyQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PasskeyQuery", q) +} + +// The SettingFunc type is an adapter to allow the use of ordinary function as a Querier. +type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SettingFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SettingQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) +} + +// The TraverseSetting type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSetting func(context.Context, *ent.SettingQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSetting) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SettingQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) +} + +// The ShareFunc type is an adapter to allow the use of ordinary function as a Querier. +type ShareFunc func(context.Context, *ent.ShareQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ShareFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ShareQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ShareQuery", q) +} + +// The TraverseShare type is an adapter to allow the use of ordinary function as Traverser. +type TraverseShare func(context.Context, *ent.ShareQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseShare) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseShare) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ShareQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ShareQuery", q) +} + +// The StoragePolicyFunc type is an adapter to allow the use of ordinary function as a Querier. +type StoragePolicyFunc func(context.Context, *ent.StoragePolicyQuery) (ent.Value, error) + +// Query calls f(ctx, q). 
+func (f StoragePolicyFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.StoragePolicyQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.StoragePolicyQuery", q) +} + +// The TraverseStoragePolicy type is an adapter to allow the use of ordinary function as Traverser. +type TraverseStoragePolicy func(context.Context, *ent.StoragePolicyQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseStoragePolicy) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseStoragePolicy) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.StoragePolicyQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.StoragePolicyQuery", q) +} + +// The TaskFunc type is an adapter to allow the use of ordinary function as a Querier. +type TaskFunc func(context.Context, *ent.TaskQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f TaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.TaskQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.TaskQuery", q) +} + +// The TraverseTask type is an adapter to allow the use of ordinary function as Traverser. +type TraverseTask func(context.Context, *ent.TaskQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseTask) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseTask) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.TaskQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.TaskQuery", q) +} + +// The UserFunc type is an adapter to allow the use of ordinary function as a Querier. +type UserFunc func(context.Context, *ent.UserQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f UserFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.UserQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserQuery", q) +} + +// The TraverseUser type is an adapter to allow the use of ordinary function as Traverser. +type TraverseUser func(context.Context, *ent.UserQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseUser) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseUser) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.UserQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.UserQuery", q) +} + +// NewQuery returns the generic Query interface for the given typed query. 
+func NewQuery(q ent.Query) (Query, error) { + switch q := q.(type) { + case *ent.DavAccountQuery: + return &query[*ent.DavAccountQuery, predicate.DavAccount, davaccount.OrderOption]{typ: ent.TypeDavAccount, tq: q}, nil + case *ent.DirectLinkQuery: + return &query[*ent.DirectLinkQuery, predicate.DirectLink, directlink.OrderOption]{typ: ent.TypeDirectLink, tq: q}, nil + case *ent.EntityQuery: + return &query[*ent.EntityQuery, predicate.Entity, entity.OrderOption]{typ: ent.TypeEntity, tq: q}, nil + case *ent.FileQuery: + return &query[*ent.FileQuery, predicate.File, file.OrderOption]{typ: ent.TypeFile, tq: q}, nil + case *ent.GroupQuery: + return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil + case *ent.MetadataQuery: + return &query[*ent.MetadataQuery, predicate.Metadata, metadata.OrderOption]{typ: ent.TypeMetadata, tq: q}, nil + case *ent.NodeQuery: + return &query[*ent.NodeQuery, predicate.Node, node.OrderOption]{typ: ent.TypeNode, tq: q}, nil + case *ent.PasskeyQuery: + return &query[*ent.PasskeyQuery, predicate.Passkey, passkey.OrderOption]{typ: ent.TypePasskey, tq: q}, nil + case *ent.SettingQuery: + return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil + case *ent.ShareQuery: + return &query[*ent.ShareQuery, predicate.Share, share.OrderOption]{typ: ent.TypeShare, tq: q}, nil + case *ent.StoragePolicyQuery: + return &query[*ent.StoragePolicyQuery, predicate.StoragePolicy, storagepolicy.OrderOption]{typ: ent.TypeStoragePolicy, tq: q}, nil + case *ent.TaskQuery: + return &query[*ent.TaskQuery, predicate.Task, task.OrderOption]{typ: ent.TypeTask, tq: q}, nil + case *ent.UserQuery: + return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil + default: + return nil, fmt.Errorf("unknown query type %T", q) + } +} + +type query[T any, P ~func(*sql.Selector), R ~func(*sql.Selector)] struct { + typ string + tq interface { + Limit(int) T + Offset(int) T + Unique(bool) T + Order(...R) T + Where(...P) T + } +} + +func (q query[T, P, R]) Type() string { + return q.typ +} + +func (q query[T, P, R]) Limit(limit int) { + q.tq.Limit(limit) +} + +func (q query[T, P, R]) Offset(offset int) { + q.tq.Offset(offset) +} + +func (q query[T, P, R]) Unique(unique bool) { + q.tq.Unique(unique) +} + +func (q query[T, P, R]) Order(orders ...func(*sql.Selector)) { + rs := make([]R, len(orders)) + for i := range orders { + rs[i] = orders[i] + } + q.tq.Order(rs...) +} + +func (q query[T, P, R]) WhereP(ps ...func(*sql.Selector)) { + p := make([]P, len(ps)) + for i := range ps { + p[i] = ps[i] + } + q.tq.Where(p...) +} diff --git a/ent/internal/schema.go b/ent/internal/schema.go new file mode 100644 index 00000000..e639bdfd --- /dev/null +++ b/ent/internal/schema.go @@ -0,0 +1,9 @@ +// Code generated by ent, DO NOT EDIT. + +//go:build tools +// +build tools + +// Package internal holds a loadable version of the latest schema. 
+package internal + +const Schema = "{\"Schema\":\"github.com/cloudreve/Cloudreve/v4/ent/schema\",\"Package\":\"github.com/cloudreve/Cloudreve/v4/ent\",\"Schemas\":[{\"name\":\"DavAccount\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"owner\",\"type\":\"User\",\"field\":\"owner_id\",\"ref_name\":\"dav_accounts\",\"unique\":true,\"inverse\":true,\"required\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"uri\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"password\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0},\"sensitive\":true},{\"name\":\"options\",\"type\":{\"Type\":5,\"Ident\":\"*boolset.BooleanSet\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"PkgName\":\"boolset\",\"Nillable\":true,\"RType\":{\"Name\":\"BooleanSet\",\"Ident\":\"boolset.BooleanSet\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"Methods\":{\"Enabled\":{\"In\":[{\"Name\":\"int\",\"Ident\":\"int\",\"Kind\":2,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"bool\",\"Ident\":\"bool\",\"Kind\":1,\"PkgPath\":\"\",\"Methods\":null}]},\"MarshalBinary\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Scan\":{\"In\":[{\"Name\":\"\",\"Ident\":\"interface 
{}\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"String\":{\"In\":[],\"Out\":[{\"Name\":\"string\",\"Ident\":\"string\",\"Kind\":24,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalBinary\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Value\":{\"In\":[],\"Out\":[{\"Name\":\"Value\",\"Ident\":\"driver.Value\",\"Kind\":20,\"PkgPath\":\"database/sql/driver\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"props\",\"type\":{\"Type\":3,\"Ident\":\"*types.DavAccountProps\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"DavAccountProps\",\"Ident\":\"types.DavAccountProps\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"owner_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}}],\"indexes\":[{\"unique\":true,\"fields\":[\"owner_id\",\"password\"]}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"DirectLink\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"file\",\"type\":\"File\",\"field\":\"file_id\",\"ref_name\":\"direct_links\",\"unique\":true,\"inverse\":true,\"required\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"downloads\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"file_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"speed\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"
Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Entity\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"file\",\"type\":\"File\",\"ref_name\":\"entities\",\"inverse\":true},{\"name\":\"user\",\"type\":\"User\",\"field\":\"created_by\",\"ref_name\":\"entities\",\"unique\":true,\"inverse\":true},{\"name\":\"storage_policy\",\"type\":\"StoragePolicy\",\"field\":\"storage_policy_entities\",\"ref_name\":\"entities\",\"unique\":true,\"inverse\":true,\"required\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"type\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"source\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"size\",\"type\":{\"Type\":13,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"reference_count\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":1,\"default_kind\":2,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"storage_policy_entities\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"created_by\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"upload_session_id\",\"type\":{\"Type\":4,\"Ident\":\"uuid.UUID\",\"PkgPath\":\"github.com/gofrs/uuid\",\"PkgName\":\"uuid\",\"Nillable\":false,\"RType\":{\"Name\":\"UUID\",\"Ident\":\"uuid.UUID\",\"Kind\":17,\"PkgPath\":\"github.com/gofrs/uuid\",\"Methods\":{\"Bytes\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}]},\"Format\":{\"In\":[{\"Name\":\"State\",\"Ident\":\"fmt.State\",\"Kind\":20,\"PkgPath\":\"fmt\",\"Methods\":null},{\"Name\":\"int32\",\"Ident\":\"int32\",\"Kind\":5,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"MarshalBinary\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"MarshalText\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\
":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Scan\":{\"In\":[{\"Name\":\"\",\"Ident\":\"interface {}\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"SetVariant\":{\"In\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"SetVersion\":{\"In\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"String\":{\"In\":[],\"Out\":[{\"Name\":\"string\",\"Ident\":\"string\",\"Kind\":24,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalBinary\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalText\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Value\":{\"In\":[],\"Out\":[{\"Name\":\"Value\",\"Ident\":\"driver.Value\",\"Kind\":20,\"PkgPath\":\"database/sql/driver\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Variant\":{\"In\":[],\"Out\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}]},\"Version\":{\"In\":[],\"Out\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":6,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"recycle_options\",\"type\":{\"Type\":3,\"Ident\":\"*types.EntityRecycleOption\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"EntityRecycleOption\",\"Ident\":\"types.EntityRecycleOption\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"position\":{\"Index\":7,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"File\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"owner\",\"type\":\"User\",\"field\":\"owner_id\",\"ref_name\":\"files\",\"unique\":true,\"inverse\":true,\"required\":true},{\"name\":\"storage_policies\",\"type\":\"StoragePolicy\",\"field\":\"storage_policy_files\",\"ref_name\":\"files\",\"unique\":true,\"inverse\":true},{\"name\":\"parent\",\"type\":\"File\",\"field\":\"file_children\",\"ref\":{\"name\":\"children\",\"type\":\"File\"},\"unique\":true,\"inverse\":true},{\"name\":\"metadata\",\"type\":\"Metadata\"},{\"name\":\"entities\",\"type\":\"Entity\"},{\"name\":\"shares\",\"type\":\"Share\"},{\"name\":\"direct_links\",\"type\":\"DirectLink\"}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema
_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"type\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"owner_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"size\",\"type\":{\"Type\":13,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":0,\"default_kind\":6,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"primary_entity\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"file_children\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"is_symbolic\",\"type\":{\"Type\":1,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":false,\"default_kind\":1,\"position\":{\"Index\":6,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"props\",\"type\":{\"Type\":3,\"Ident\":\"*types.FileProps\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"FileProps\",\"Ident\":\"types.FileProps\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"position\":{\"Index\":7,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"storage_policy_files\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":8,\"MixedIn\":false,\"MixinIndex\":0}}],\"indexes\":[{\"unique\":true,\"fields\":[\"file_children\",\"name\"]},{\"fields\":[\"file_children\",\"type\",\"updated_at\"]},{\"fields\":[\"file_children\",\"type\",\"size\"]}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Group\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"users\",\"type\":\"User\"},{\"name\":\"storage_policies\",\"type\":\"StoragePolicy\",\"field\":\"storage_policy_id\",\"ref_name\":\"groups\",\"unique\":true,\"inverse\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\
":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"max_storage\",\"type\":{\"Type\":13,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"speed_limit\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"permissions\",\"type\":{\"Type\":5,\"Ident\":\"*boolset.BooleanSet\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"PkgName\":\"boolset\",\"Nillable\":true,\"RType\":{\"Name\":\"BooleanSet\",\"Ident\":\"boolset.BooleanSet\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"Methods\":{\"Enabled\":{\"In\":[{\"Name\":\"int\",\"Ident\":\"int\",\"Kind\":2,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"bool\",\"Ident\":\"bool\",\"Kind\":1,\"PkgPath\":\"\",\"Methods\":null}]},\"MarshalBinary\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Scan\":{\"In\":[{\"Name\":\"\",\"Ident\":\"interface {}\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"String\":{\"In\":[],\"Out\":[{\"Name\":\"string\",\"Ident\":\"string\",\"Kind\":24,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalBinary\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Value\":{\"In\":[],\"Out\":[{\"Name\":\"Value\",\"Ident\":\"driver.Value\",\"Kind\":20,\"PkgPath\":\"database/sql/driver\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"settings\",\"type\":{\"Type\":3,\"Ident\":\"*types.GroupSetting\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"GroupSetting\",\"Ident\":\"types.GroupSetting\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"default\":true,\"default_value\":{},\"default_kind\":22,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"storage_policy_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Metadata\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"file\",\"type\":\"File\",\"field\":\"
file_id\",\"ref_name\":\"metadata\",\"unique\":true,\"inverse\":true,\"required\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"value\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"file_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"is_public\",\"type\":{\"Type\":1,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":false,\"default_kind\":1,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}}],\"indexes\":[{\"unique\":true,\"fields\":[\"file_id\",\"name\"]}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Node\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"storage_policy\",\"type\":\"StoragePolicy\"}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"status\",\"type\":{\"Type\":6,\"Ident\":\"node.Status\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"enums\":[{\"N\":\"active\",\"V\":\"active\"},{\"N\":\"suspended\",\"V\":\"suspended\"}],\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"type\",\"type\":{\"Type\":6,\"Ident\":\"node.Type\",\"PkgPath\":\"\",\"PkgNa
me\":\"\",\"Nillable\":false,\"RType\":null},\"enums\":[{\"N\":\"master\",\"V\":\"master\"},{\"N\":\"slave\",\"V\":\"slave\"}],\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"server\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"slave_key\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"capabilities\",\"type\":{\"Type\":5,\"Ident\":\"*boolset.BooleanSet\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"PkgName\":\"boolset\",\"Nillable\":true,\"RType\":{\"Name\":\"BooleanSet\",\"Ident\":\"boolset.BooleanSet\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/pkg/boolset\",\"Methods\":{\"Enabled\":{\"In\":[{\"Name\":\"int\",\"Ident\":\"int\",\"Kind\":2,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"bool\",\"Ident\":\"bool\",\"Kind\":1,\"PkgPath\":\"\",\"Methods\":null}]},\"MarshalBinary\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Scan\":{\"In\":[{\"Name\":\"\",\"Ident\":\"interface {}\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"String\":{\"In\":[],\"Out\":[{\"Name\":\"string\",\"Ident\":\"string\",\"Kind\":24,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalBinary\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Value\":{\"In\":[],\"Out\":[{\"Name\":\"Value\",\"Ident\":\"driver.Value\",\"Kind\":20,\"PkgPath\":\"database/sql/driver\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"settings\",\"type\":{\"Type\":3,\"Ident\":\"*types.NodeSetting\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"NodeSetting\",\"Ident\":\"types.NodeSetting\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"default\":true,\"default_value\":{},\"default_kind\":22,\"position\":{\"Index\":6,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"weight\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":0,\"default_kind\":2,\"position\":{\"Index\":7,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Passkey\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"user\",\"type\":\"User\",\"field\":\"user_id\",\"ref_name\":\"passkey\",\"unique\":true,\"inverse\":true,\"required\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\"
:{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"user_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"credential_id\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"credential\",\"type\":{\"Type\":3,\"Ident\":\"*webauthn.Credential\",\"PkgPath\":\"github.com/go-webauthn/webauthn/webauthn\",\"PkgName\":\"webauthn\",\"Nillable\":true,\"RType\":{\"Name\":\"Credential\",\"Ident\":\"webauthn.Credential\",\"Kind\":22,\"PkgPath\":\"github.com/go-webauthn/webauthn/webauthn\",\"Methods\":{\"Descriptor\":{\"In\":[],\"Out\":[{\"Name\":\"CredentialDescriptor\",\"Ident\":\"protocol.CredentialDescriptor\",\"Kind\":25,\"PkgPath\":\"github.com/go-webauthn/webauthn/protocol\",\"Methods\":null}]},\"Verify\":{\"In\":[{\"Name\":\"Provider\",\"Ident\":\"metadata.Provider\",\"Kind\":20,\"PkgPath\":\"github.com/go-webauthn/webauthn/metadata\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0},\"sensitive\":true},{\"name\":\"used_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}}],\"indexes\":[{\"unique\":true,\"fields\":[\"user_id\",\"credential_id\"]}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Setting\",\"config\":{\"Table\":\"\"},\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\"
:7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"unique\":true,\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"value\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"optional\":true,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Share\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"user\",\"type\":\"User\",\"ref_name\":\"shares\",\"unique\":true,\"inverse\":true},{\"name\":\"file\",\"type\":\"File\",\"ref_name\":\"shares\",\"unique\":true,\"inverse\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"password\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"views\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":0,\"default_kind\":2,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"downloads\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":0,\"default_kind\":2,\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"expires\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"remain_downloads\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"StoragePolicy\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"users\",\"type\":\"User\"},{\"name\":\"groups\",\"type\":\"Group\"},{\"name\":\"files\",\"type\":\"File\"},{\"name\":\"entities\",\"type\":\"Entity\"},{\"name\":\"node\",\"type\":\"Node\",\"field\":\"node_id\",\"ref_name\":\"storage_policy\",\"unique\":true,\"inverse\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\
":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"type\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"server\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"bucket_name\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"is_private\",\"type\":{\"Type\":1,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"access_key\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"secret_key\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"optional\":true,\"position\":{\"Index\":6,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"max_size\",\"type\":{\"Type\":13,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":7,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"dir_name_rule\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":8,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"file_name_rule\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":9,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"settings\",\"type\":{\"Type\":3,\"Ident\":\"*types.PolicySetting\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"PolicySetting\",\"Ident\":\"types.PolicySetting\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"default\":true,\"default_value\":{\"file_type\":null,\"native_media_processing\":false,\"s3_path_style\":false,\"token\":\"\"},\"default_kind\":22,\"position\":{\"Index\":10,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"node_id\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":11,\"MixedIn\":false,\"MixinIndex\
":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"Task\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"user\",\"type\":\"User\",\"field\":\"user_tasks\",\"ref_name\":\"tasks\",\"unique\":true,\"inverse\":true}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"type\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"status\",\"type\":{\"Type\":6,\"Ident\":\"task.Status\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"enums\":[{\"N\":\"queued\",\"V\":\"queued\"},{\"N\":\"processing\",\"V\":\"processing\"},{\"N\":\"suspending\",\"V\":\"suspending\"},{\"N\":\"error\",\"V\":\"error\"},{\"N\":\"canceled\",\"V\":\"canceled\"},{\"N\":\"completed\",\"V\":\"completed\"}],\"default\":true,\"default_value\":\"queued\",\"default_kind\":24,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"public_state\",\"type\":{\"Type\":3,\"Ident\":\"*types.TaskPublicState\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"TaskPublicState\",\"Ident\":\"types.TaskPublicState\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"private_state\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":2147483647,\"optional\":true,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"correlation_id\",\"type\":{\"Type\":4,\"Ident\":\"uuid.UUID\",\"PkgPath\":\"github.com/gofrs/uuid\",\"PkgName\":\"uuid\",\"Nillable\":false,\"RType\":{\"Name\":\"UUID\",\"Ident\":\"uuid.UUID\",\"Kind\":17,\"PkgPath\":\"github.com/gofrs/uuid\",\"Methods\":{\"Bytes\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}]},\"Format\":{\"In\":[{\"Name\":\"State\",\"Ident\":\"fmt.State\",\"Kind\":20,\"PkgPath\":\"fmt\",\"Methods\":null},{\"Name\":\"int32\",\"Ident\":\"int32\",\"Kind\":5,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"MarshalBinary\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"MarshalText\":{\"In\":[],\"Out\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Met
hods\":null}]},\"Scan\":{\"In\":[{\"Name\":\"\",\"Ident\":\"interface {}\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"SetVariant\":{\"In\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"SetVersion\":{\"In\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[]},\"String\":{\"In\":[],\"Out\":[{\"Name\":\"string\",\"Ident\":\"string\",\"Kind\":24,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalBinary\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"UnmarshalText\":{\"In\":[{\"Name\":\"\",\"Ident\":\"[]uint8\",\"Kind\":23,\"PkgPath\":\"\",\"Methods\":null}],\"Out\":[{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Value\":{\"In\":[],\"Out\":[{\"Name\":\"Value\",\"Ident\":\"driver.Value\",\"Kind\":20,\"PkgPath\":\"database/sql/driver\",\"Methods\":null},{\"Name\":\"error\",\"Ident\":\"error\",\"Kind\":20,\"PkgPath\":\"\",\"Methods\":null}]},\"Variant\":{\"In\":[],\"Out\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}]},\"Version\":{\"In\":[],\"Out\":[{\"Name\":\"uint8\",\"Ident\":\"uint8\",\"Kind\":8,\"PkgPath\":\"\",\"Methods\":null}]}}}},\"optional\":true,\"immutable\":true,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"user_tasks\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]},{\"name\":\"User\",\"config\":{\"Table\":\"\"},\"edges\":[{\"name\":\"group\",\"type\":\"Group\",\"field\":\"group_users\",\"ref_name\":\"users\",\"unique\":true,\"inverse\":true,\"required\":true},{\"name\":\"files\",\"type\":\"File\"},{\"name\":\"dav_accounts\",\"type\":\"DavAccount\"},{\"name\":\"shares\",\"type\":\"Share\"},{\"name\":\"passkey\",\"type\":\"Passkey\"},{\"name\":\"tasks\",\"type\":\"Task\"},{\"name\":\"entities\",\"type\":\"Entity\"}],\"fields\":[{\"name\":\"created_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"immutable\":true,\"position\":{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"updated_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_kind\":19,\"update_default\":true,\"position\":{\"Index\":1,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"deleted_at\",\"type\":{\"Type\":2,\"Ident\":\"\",\"PkgPath\":\"time\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"nillable\":true,\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":true,\"MixinIndex\":0},\"schema_type\":{\"mysql\":\"datetime\"}},{\"name\":\"email\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":100,\"unique\":true,\"validators\":1,\"position\":{\"Index\":0,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"nick\",\"type\":{\"Type\":7,\"Ident\":\"\",\"Pkg
Path\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"size\":100,\"validators\":1,\"position\":{\"Index\":1,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"password\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":2,\"MixedIn\":false,\"MixinIndex\":0},\"sensitive\":true},{\"name\":\"status\",\"type\":{\"Type\":6,\"Ident\":\"user.Status\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"enums\":[{\"N\":\"active\",\"V\":\"active\"},{\"N\":\"inactive\",\"V\":\"inactive\"},{\"N\":\"manual_banned\",\"V\":\"manual_banned\"},{\"N\":\"sys_banned\",\"V\":\"sys_banned\"}],\"default\":true,\"default_value\":\"active\",\"default_kind\":24,\"position\":{\"Index\":3,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"storage\",\"type\":{\"Type\":13,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"default\":true,\"default_value\":0,\"default_kind\":6,\"position\":{\"Index\":4,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"two_factor_secret\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":5,\"MixedIn\":false,\"MixinIndex\":0},\"sensitive\":true},{\"name\":\"avatar\",\"type\":{\"Type\":7,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"optional\":true,\"position\":{\"Index\":6,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"settings\",\"type\":{\"Type\":3,\"Ident\":\"*types.UserSetting\",\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"PkgName\":\"types\",\"Nillable\":true,\"RType\":{\"Name\":\"UserSetting\",\"Ident\":\"types.UserSetting\",\"Kind\":22,\"PkgPath\":\"github.com/cloudreve/Cloudreve/v4/inventory/types\",\"Methods\":{}}},\"optional\":true,\"default\":true,\"default_value\":{},\"default_kind\":22,\"position\":{\"Index\":7,\"MixedIn\":false,\"MixinIndex\":0}},{\"name\":\"group_users\",\"type\":{\"Type\":12,\"Ident\":\"\",\"PkgPath\":\"\",\"PkgName\":\"\",\"Nillable\":false,\"RType\":null},\"position\":{\"Index\":8,\"MixedIn\":false,\"MixinIndex\":0}}],\"hooks\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}],\"interceptors\":[{\"Index\":0,\"MixedIn\":true,\"MixinIndex\":0}]}],\"Features\":[\"intercept\",\"schema/snapshot\",\"sql/upsert\",\"sql/upsert\",\"sql/execquery\"]}" diff --git a/ent/metadata.go b/ent/metadata.go new file mode 100644 index 00000000..d02a58c3 --- /dev/null +++ b/ent/metadata.go @@ -0,0 +1,214 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" +) + +// Metadata is the model entity for the Metadata schema. +type Metadata struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + // FileID holds the value of the "file_id" field. 
+ FileID int `json:"file_id,omitempty"` + // IsPublic holds the value of the "is_public" field. + IsPublic bool `json:"is_public,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the MetadataQuery when eager-loading is set. + Edges MetadataEdges `json:"edges"` + selectValues sql.SelectValues +} + +// MetadataEdges holds the relations/edges for other nodes in the graph. +type MetadataEdges struct { + // File holds the value of the file edge. + File *File `json:"file,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// FileOrErr returns the File value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e MetadataEdges) FileOrErr() (*File, error) { + if e.loadedTypes[0] { + if e.File == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: file.Label} + } + return e.File, nil + } + return nil, &NotLoadedError{edge: "file"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Metadata) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case metadata.FieldIsPublic: + values[i] = new(sql.NullBool) + case metadata.FieldID, metadata.FieldFileID: + values[i] = new(sql.NullInt64) + case metadata.FieldName, metadata.FieldValue: + values[i] = new(sql.NullString) + case metadata.FieldCreatedAt, metadata.FieldUpdatedAt, metadata.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Metadata fields. 
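For orientation, a minimal sketch of how the generated Metadata entity above is typically consumed, assuming an initialized *ent.Client and ent's standard WithFile eager-loading option on MetadataQuery (the query builder itself lives in a separate generated file not shown here); the package and helper names are hypothetical:

    package example // hypothetical package, for illustration only

    import (
        "context"

        "github.com/cloudreve/Cloudreve/v4/ent"
        "github.com/cloudreve/Cloudreve/v4/ent/metadata"
    )

    // listFileMetadata loads the metadata rows of one file together with the "file" edge,
    // so that Edges.FileOrErr (defined above) returns the parent instead of NotLoadedError.
    func listFileMetadata(ctx context.Context, client *ent.Client, fileID int) ([]*ent.Metadata, error) {
        rows, err := client.Metadata.Query().
            Where(metadata.FileID(fileID)). // predicate generated in ent/metadata/where.go
            WithFile().                     // eager-load the edge (assumed standard codegen)
            All(ctx)
        if err != nil {
            return nil, err
        }
        for _, m := range rows {
            if _, err := m.Edges.FileOrErr(); err != nil {
                return nil, err // edge requested but the referenced file row is missing
            }
        }
        return rows, nil
    }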
+func (m *Metadata) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case metadata.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + m.ID = int(value.Int64) + case metadata.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + m.CreatedAt = value.Time + } + case metadata.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + m.UpdatedAt = value.Time + } + case metadata.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + m.DeletedAt = new(time.Time) + *m.DeletedAt = value.Time + } + case metadata.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + m.Name = value.String + } + case metadata.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + m.Value = value.String + } + case metadata.FieldFileID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field file_id", values[i]) + } else if value.Valid { + m.FileID = int(value.Int64) + } + case metadata.FieldIsPublic: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_public", values[i]) + } else if value.Valid { + m.IsPublic = value.Bool + } + default: + m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue returns the ent.Value that was dynamically selected and assigned to the Metadata. +// This includes values selected through modifiers, order, etc. +func (m *Metadata) GetValue(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + +// QueryFile queries the "file" edge of the Metadata entity. +func (m *Metadata) QueryFile() *FileQuery { + return NewMetadataClient(m.config).QueryFile(m) +} + +// Update returns a builder for updating this Metadata. +// Note that you need to call Metadata.Unwrap() before calling this method if this Metadata +// was returned from a transaction, and the transaction was committed or rolled back. +func (m *Metadata) Update() *MetadataUpdateOne { + return NewMetadataClient(m.config).UpdateOne(m) +} + +// Unwrap unwraps the Metadata entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (m *Metadata) Unwrap() *Metadata { + _tx, ok := m.config.driver.(*txDriver) + if !ok { + panic("ent: Metadata is not a transactional entity") + } + m.config.driver = _tx.drv + return m +} + +// String implements the fmt.Stringer. 
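The Unwrap/Update note above is easier to read as code; a short sketch, assuming the standard generated Tx and update builders (helper name hypothetical):

    package example // hypothetical package, for illustration only

    import (
        "context"

        "github.com/cloudreve/Cloudreve/v4/ent"
        "github.com/cloudreve/Cloudreve/v4/ent/metadata"
    )

    // updateAfterTx reads a Metadata row inside a transaction, commits, and only then
    // updates it; Unwrap switches the entity back to the root driver so the update
    // does not run on the already-closed transaction.
    func updateAfterTx(ctx context.Context, client *ent.Client, name, value string) error {
        tx, err := client.Tx(ctx)
        if err != nil {
            return err
        }
        m, err := tx.Metadata.Query().Where(metadata.Name(name)).Only(ctx)
        if err != nil {
            _ = tx.Rollback()
            return err
        }
        if err := tx.Commit(); err != nil {
            return err
        }
        return m.Unwrap().Update().SetValue(value).Exec(ctx)
    }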
+func (m *Metadata) String() string { + var builder strings.Builder + builder.WriteString("Metadata(") + builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) + builder.WriteString("created_at=") + builder.WriteString(m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := m.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(m.Name) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(m.Value) + builder.WriteString(", ") + builder.WriteString("file_id=") + builder.WriteString(fmt.Sprintf("%v", m.FileID)) + builder.WriteString(", ") + builder.WriteString("is_public=") + builder.WriteString(fmt.Sprintf("%v", m.IsPublic)) + builder.WriteByte(')') + return builder.String() +} + +// SetFile manually set the edge as loaded state. +func (e *Metadata) SetFile(v *File) { + e.Edges.File = v + e.Edges.loadedTypes[0] = true +} + +// MetadataSlice is a parsable slice of Metadata. +type MetadataSlice []*Metadata diff --git a/ent/metadata/metadata.go b/ent/metadata/metadata.go new file mode 100644 index 00000000..83b02649 --- /dev/null +++ b/ent/metadata/metadata.go @@ -0,0 +1,140 @@ +// Code generated by ent, DO NOT EDIT. + +package metadata + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the metadata type in the database. + Label = "metadata" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // FieldFileID holds the string denoting the file_id field in the database. + FieldFileID = "file_id" + // FieldIsPublic holds the string denoting the is_public field in the database. + FieldIsPublic = "is_public" + // EdgeFile holds the string denoting the file edge name in mutations. + EdgeFile = "file" + // Table holds the table name of the metadata in the database. + Table = "metadata" + // FileTable is the table that holds the file relation/edge. + FileTable = "metadata" + // FileInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. + FileInverseTable = "files" + // FileColumn is the table column denoting the file relation/edge. + FileColumn = "file_id" +) + +// Columns holds all SQL columns for metadata fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldValue, + FieldFileID, + FieldIsPublic, +} + +// ValidColumn reports if the column name is valid (part of the table columns). 
+func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultIsPublic holds the default value on creation for the "is_public" field. + DefaultIsPublic bool +) + +// OrderOption defines the ordering options for the Metadata queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByFileID orders the results by the file_id field. +func ByFileID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFileID, opts...).ToFunc() +} + +// ByIsPublic orders the results by the is_public field. +func ByIsPublic(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsPublic, opts...).ToFunc() +} + +// ByFileField orders the results by file field. +func ByFileField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFileStep(), sql.OrderByField(field, opts...)) + } +} +func newFileStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FileInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) +} diff --git a/ent/metadata/where.go b/ent/metadata/where.go new file mode 100644 index 00000000..ad0fbec6 --- /dev/null +++ b/ent/metadata/where.go @@ -0,0 +1,419 @@ +// Code generated by ent, DO NOT EDIT. + +package metadata + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. 
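The OrderOption helpers above plug into the query builder's Order call; a brief sketch under the same assumptions (sql.OrderDesc is the order term from entgo.io/ent/dialect/sql):

    package example // hypothetical package, for illustration only

    import (
        "context"

        "entgo.io/ent/dialect/sql"
        "github.com/cloudreve/Cloudreve/v4/ent"
        "github.com/cloudreve/Cloudreve/v4/ent/metadata"
    )

    // recentMetadata returns up to limit rows, newest first, ties broken by name.
    // As the generated comment above notes, package main should blank-import
    // github.com/cloudreve/Cloudreve/v4/ent/runtime so Hooks and Interceptors are registered.
    func recentMetadata(ctx context.Context, client *ent.Client, limit int) ([]*ent.Metadata, error) {
        return client.Metadata.Query().
            Order(
                metadata.ByCreatedAt(sql.OrderDesc()), // OrderOption defined above
                metadata.ByName(),
            ).
            Limit(limit).
            All(ctx)
    }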
+func IDEQ(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldName, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. +func Value(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldValue, v)) +} + +// FileID applies equality check predicate on the "file_id" field. It's identical to FileIDEQ. +func FileID(v int) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldFileID, v)) +} + +// IsPublic applies equality check predicate on the "is_public" field. It's identical to IsPublicEQ. +func IsPublic(v bool) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldIsPublic, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. 
+func CreatedAtGT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. 
+func DeletedAtLT(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Metadata { + return predicate.Metadata(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Metadata { + return predicate.Metadata(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldContainsFold(FieldName, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. +func ValueEQ(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. 
+func ValueNotIn(vs ...string) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. +func ValueHasSuffix(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.Metadata { + return predicate.Metadata(sql.FieldContainsFold(FieldValue, v)) +} + +// FileIDEQ applies the EQ predicate on the "file_id" field. +func FileIDEQ(v int) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldFileID, v)) +} + +// FileIDNEQ applies the NEQ predicate on the "file_id" field. +func FileIDNEQ(v int) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldFileID, v)) +} + +// FileIDIn applies the In predicate on the "file_id" field. +func FileIDIn(vs ...int) predicate.Metadata { + return predicate.Metadata(sql.FieldIn(FieldFileID, vs...)) +} + +// FileIDNotIn applies the NotIn predicate on the "file_id" field. +func FileIDNotIn(vs ...int) predicate.Metadata { + return predicate.Metadata(sql.FieldNotIn(FieldFileID, vs...)) +} + +// IsPublicEQ applies the EQ predicate on the "is_public" field. +func IsPublicEQ(v bool) predicate.Metadata { + return predicate.Metadata(sql.FieldEQ(FieldIsPublic, v)) +} + +// IsPublicNEQ applies the NEQ predicate on the "is_public" field. +func IsPublicNEQ(v bool) predicate.Metadata { + return predicate.Metadata(sql.FieldNEQ(FieldIsPublic, v)) +} + +// HasFile applies the HasEdge predicate on the "file" edge. +func HasFile() predicate.Metadata { + return predicate.Metadata(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFileWith applies the HasEdge predicate on the "file" edge with a given conditions (other predicates). 
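The field predicates above, the HasFileWith edge predicate whose implementation follows, and the And/Or/Not combinators at the end of this file all compose inside a query's Where clause. A minimal usage sketch, assuming the usual generated *ent.Client is available as client; the query itself relies on the MetadataQuery builder added later in this change:

    import (
        "context"

        "github.com/cloudreve/Cloudreve/v4/ent"
        "github.com/cloudreve/Cloudreve/v4/ent/metadata"
    )

    // listPublicMetadata returns the non-deleted, public metadata rows of one file.
    func listPublicMetadata(ctx context.Context, client *ent.Client, fileID int) ([]*ent.Metadata, error) {
        return client.Metadata.Query().
            Where(
                metadata.FileIDEQ(fileID),
                metadata.IsPublicEQ(true),
                metadata.DeletedAtIsNil(),
            ).
            All(ctx)
    }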
+func HasFileWith(preds ...predicate.File) predicate.Metadata { + return predicate.Metadata(func(s *sql.Selector) { + step := newFileStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Metadata) predicate.Metadata { + return predicate.Metadata(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Metadata) predicate.Metadata { + return predicate.Metadata(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Metadata) predicate.Metadata { + return predicate.Metadata(sql.NotPredicates(p)) +} diff --git a/ent/metadata_create.go b/ent/metadata_create.go new file mode 100644 index 00000000..10a774a1 --- /dev/null +++ b/ent/metadata_create.go @@ -0,0 +1,855 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" +) + +// MetadataCreate is the builder for creating a Metadata entity. +type MetadataCreate struct { + config + mutation *MetadataMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (mc *MetadataCreate) SetCreatedAt(t time.Time) *MetadataCreate { + mc.mutation.SetCreatedAt(t) + return mc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (mc *MetadataCreate) SetNillableCreatedAt(t *time.Time) *MetadataCreate { + if t != nil { + mc.SetCreatedAt(*t) + } + return mc +} + +// SetUpdatedAt sets the "updated_at" field. +func (mc *MetadataCreate) SetUpdatedAt(t time.Time) *MetadataCreate { + mc.mutation.SetUpdatedAt(t) + return mc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (mc *MetadataCreate) SetNillableUpdatedAt(t *time.Time) *MetadataCreate { + if t != nil { + mc.SetUpdatedAt(*t) + } + return mc +} + +// SetDeletedAt sets the "deleted_at" field. +func (mc *MetadataCreate) SetDeletedAt(t time.Time) *MetadataCreate { + mc.mutation.SetDeletedAt(t) + return mc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (mc *MetadataCreate) SetNillableDeletedAt(t *time.Time) *MetadataCreate { + if t != nil { + mc.SetDeletedAt(*t) + } + return mc +} + +// SetName sets the "name" field. +func (mc *MetadataCreate) SetName(s string) *MetadataCreate { + mc.mutation.SetName(s) + return mc +} + +// SetValue sets the "value" field. +func (mc *MetadataCreate) SetValue(s string) *MetadataCreate { + mc.mutation.SetValue(s) + return mc +} + +// SetFileID sets the "file_id" field. +func (mc *MetadataCreate) SetFileID(i int) *MetadataCreate { + mc.mutation.SetFileID(i) + return mc +} + +// SetIsPublic sets the "is_public" field. +func (mc *MetadataCreate) SetIsPublic(b bool) *MetadataCreate { + mc.mutation.SetIsPublic(b) + return mc +} + +// SetNillableIsPublic sets the "is_public" field if the given value is not nil. +func (mc *MetadataCreate) SetNillableIsPublic(b *bool) *MetadataCreate { + if b != nil { + mc.SetIsPublic(*b) + } + return mc +} + +// SetFile sets the "file" edge to the File entity. 
+func (mc *MetadataCreate) SetFile(f *File) *MetadataCreate { + return mc.SetFileID(f.ID) +} + +// Mutation returns the MetadataMutation object of the builder. +func (mc *MetadataCreate) Mutation() *MetadataMutation { + return mc.mutation +} + +// Save creates the Metadata in the database. +func (mc *MetadataCreate) Save(ctx context.Context) (*Metadata, error) { + if err := mc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (mc *MetadataCreate) SaveX(ctx context.Context) *Metadata { + v, err := mc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mc *MetadataCreate) Exec(ctx context.Context) error { + _, err := mc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mc *MetadataCreate) ExecX(ctx context.Context) { + if err := mc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (mc *MetadataCreate) defaults() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + if metadata.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized metadata.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := metadata.DefaultCreatedAt() + mc.mutation.SetCreatedAt(v) + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + if metadata.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized metadata.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := metadata.DefaultUpdatedAt() + mc.mutation.SetUpdatedAt(v) + } + if _, ok := mc.mutation.IsPublic(); !ok { + v := metadata.DefaultIsPublic + mc.mutation.SetIsPublic(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
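With the setters and Save/Exec variants above, creating a row is a single fluent chain; defaults() fills created_at and updated_at when they are not set explicitly, and check(), whose implementation follows, rejects the insert if name, value, or the required file edge is missing. A minimal sketch, same client assumption and imports as the earlier predicate example; the key/value pair is invented for illustration:

    // attachMetadata stores one name/value pair on an existing file.
    func attachMetadata(ctx context.Context, client *ent.Client, fileID int) (*ent.Metadata, error) {
        return client.Metadata.Create().
            SetFileID(fileID).
            SetName("example:key").
            SetValue("example-value").
            SetIsPublic(false).
            Save(ctx)
    }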
+func (mc *MetadataCreate) check() error { + if _, ok := mc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Metadata.created_at"`)} + } + if _, ok := mc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Metadata.updated_at"`)} + } + if _, ok := mc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Metadata.name"`)} + } + if _, ok := mc.mutation.Value(); !ok { + return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "Metadata.value"`)} + } + if _, ok := mc.mutation.FileID(); !ok { + return &ValidationError{Name: "file_id", err: errors.New(`ent: missing required field "Metadata.file_id"`)} + } + if _, ok := mc.mutation.IsPublic(); !ok { + return &ValidationError{Name: "is_public", err: errors.New(`ent: missing required field "Metadata.is_public"`)} + } + if _, ok := mc.mutation.FileID(); !ok { + return &ValidationError{Name: "file", err: errors.New(`ent: missing required edge "Metadata.file"`)} + } + return nil +} + +func (mc *MetadataCreate) sqlSave(ctx context.Context) (*Metadata, error) { + if err := mc.check(); err != nil { + return nil, err + } + _node, _spec := mc.createSpec() + if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + mc.mutation.id = &_node.ID + mc.mutation.done = true + return _node, nil +} + +func (mc *MetadataCreate) createSpec() (*Metadata, *sqlgraph.CreateSpec) { + var ( + _node = &Metadata{config: mc.config} + _spec = sqlgraph.NewCreateSpec(metadata.Table, sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt)) + ) + + if id, ok := mc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = mc.conflict + if value, ok := mc.mutation.CreatedAt(); ok { + _spec.SetField(metadata.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := mc.mutation.UpdatedAt(); ok { + _spec.SetField(metadata.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := mc.mutation.DeletedAt(); ok { + _spec.SetField(metadata.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := mc.mutation.Name(); ok { + _spec.SetField(metadata.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := mc.mutation.Value(); ok { + _spec.SetField(metadata.FieldValue, field.TypeString, value) + _node.Value = value + } + if value, ok := mc.mutation.IsPublic(); ok { + _spec.SetField(metadata.FieldIsPublic, field.TypeBool, value) + _node.IsPublic = value + } + if nodes := mc.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: metadata.FileTable, + Columns: []string{metadata.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.FileID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Metadata.Create(). +// SetCreatedAt(v). 
+// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.MetadataUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (mc *MetadataCreate) OnConflict(opts ...sql.ConflictOption) *MetadataUpsertOne { + mc.conflict = opts + return &MetadataUpsertOne{ + create: mc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (mc *MetadataCreate) OnConflictColumns(columns ...string) *MetadataUpsertOne { + mc.conflict = append(mc.conflict, sql.ConflictColumns(columns...)) + return &MetadataUpsertOne{ + create: mc, + } +} + +type ( + // MetadataUpsertOne is the builder for "upsert"-ing + // one Metadata node. + MetadataUpsertOne struct { + create *MetadataCreate + } + + // MetadataUpsert is the "OnConflict" setter. + MetadataUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *MetadataUpsert) SetUpdatedAt(v time.Time) *MetadataUpsert { + u.Set(metadata.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateUpdatedAt() *MetadataUpsert { + u.SetExcluded(metadata.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *MetadataUpsert) SetDeletedAt(v time.Time) *MetadataUpsert { + u.Set(metadata.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateDeletedAt() *MetadataUpsert { + u.SetExcluded(metadata.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *MetadataUpsert) ClearDeletedAt() *MetadataUpsert { + u.SetNull(metadata.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *MetadataUpsert) SetName(v string) *MetadataUpsert { + u.Set(metadata.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateName() *MetadataUpsert { + u.SetExcluded(metadata.FieldName) + return u +} + +// SetValue sets the "value" field. +func (u *MetadataUpsert) SetValue(v string) *MetadataUpsert { + u.Set(metadata.FieldValue, v) + return u +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateValue() *MetadataUpsert { + u.SetExcluded(metadata.FieldValue) + return u +} + +// SetFileID sets the "file_id" field. +func (u *MetadataUpsert) SetFileID(v int) *MetadataUpsert { + u.Set(metadata.FieldFileID, v) + return u +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateFileID() *MetadataUpsert { + u.SetExcluded(metadata.FieldFileID) + return u +} + +// SetIsPublic sets the "is_public" field. +func (u *MetadataUpsert) SetIsPublic(v bool) *MetadataUpsert { + u.Set(metadata.FieldIsPublic, v) + return u +} + +// UpdateIsPublic sets the "is_public" field to the value that was provided on create. +func (u *MetadataUpsert) UpdateIsPublic() *MetadataUpsert { + u.SetExcluded(metadata.FieldIsPublic) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. 
+// Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *MetadataUpsertOne) UpdateNewValues() *MetadataUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(metadata.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *MetadataUpsertOne) Ignore() *MetadataUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *MetadataUpsertOne) DoNothing() *MetadataUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the MetadataCreate.OnConflict +// documentation for more info. +func (u *MetadataUpsertOne) Update(set func(*MetadataUpsert)) *MetadataUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&MetadataUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *MetadataUpsertOne) SetUpdatedAt(v time.Time) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateUpdatedAt() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *MetadataUpsertOne) SetDeletedAt(v time.Time) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateDeletedAt() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *MetadataUpsertOne) ClearDeletedAt() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *MetadataUpsertOne) SetName(v string) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateName() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateName() + }) +} + +// SetValue sets the "value" field. +func (u *MetadataUpsertOne) SetValue(v string) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateValue() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateValue() + }) +} + +// SetFileID sets the "file_id" field. 
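The upsert helpers (MetadataUpsertOne and the per-field Set*/Update* methods that continue below) wrap ent's ON CONFLICT support. A minimal sketch; the conflict target on (file_id, name) assumes a unique index that is not shown in this diff, so treat the column choice as illustrative:

    // upsertMetadata inserts the pair or, on conflict, overwrites its value.
    func upsertMetadata(ctx context.Context, client *ent.Client, fileID int, name, value string) error {
        return client.Metadata.Create().
            SetFileID(fileID).
            SetName(name).
            SetValue(value).
            OnConflictColumns(metadata.FieldFileID, metadata.FieldName).
            Update(func(u *ent.MetadataUpsert) {
                u.SetValue(value)   // overwrite the stored value
                u.UpdateUpdatedAt() // take updated_at from the newly proposed row
            }).
            Exec(ctx)
    }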
+func (u *MetadataUpsertOne) SetFileID(v int) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetFileID(v) + }) +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateFileID() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateFileID() + }) +} + +// SetIsPublic sets the "is_public" field. +func (u *MetadataUpsertOne) SetIsPublic(v bool) *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.SetIsPublic(v) + }) +} + +// UpdateIsPublic sets the "is_public" field to the value that was provided on create. +func (u *MetadataUpsertOne) UpdateIsPublic() *MetadataUpsertOne { + return u.Update(func(s *MetadataUpsert) { + s.UpdateIsPublic() + }) +} + +// Exec executes the query. +func (u *MetadataUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for MetadataCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *MetadataUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *MetadataUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *MetadataUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *MetadataCreate) SetRawID(t int) *MetadataCreate { + m.mutation.SetRawID(t) + return m +} + +// MetadataCreateBulk is the builder for creating many Metadata entities in bulk. +type MetadataCreateBulk struct { + config + err error + builders []*MetadataCreate + conflict []sql.ConflictOption +} + +// Save creates the Metadata entities in the database. +func (mcb *MetadataCreateBulk) Save(ctx context.Context) ([]*Metadata, error) { + if mcb.err != nil { + return nil, mcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) + nodes := make([]*Metadata, len(mcb.builders)) + mutators := make([]Mutator, len(mcb.builders)) + for i := range mcb.builders { + func(i int, root context.Context) { + builder := mcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*MetadataMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = mcb.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, mcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, mcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (mcb *MetadataCreateBulk) SaveX(ctx context.Context) []*Metadata { + v, err := mcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mcb *MetadataCreateBulk) Exec(ctx context.Context) error { + _, err := mcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mcb *MetadataCreateBulk) ExecX(ctx context.Context) { + if err := mcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Metadata.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.MetadataUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (mcb *MetadataCreateBulk) OnConflict(opts ...sql.ConflictOption) *MetadataUpsertBulk { + mcb.conflict = opts + return &MetadataUpsertBulk{ + create: mcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (mcb *MetadataCreateBulk) OnConflictColumns(columns ...string) *MetadataUpsertBulk { + mcb.conflict = append(mcb.conflict, sql.ConflictColumns(columns...)) + return &MetadataUpsertBulk{ + create: mcb, + } +} + +// MetadataUpsertBulk is the builder for "upsert"-ing +// a bulk of Metadata nodes. +type MetadataUpsertBulk struct { + create *MetadataCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *MetadataUpsertBulk) UpdateNewValues() *MetadataUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(metadata.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Metadata.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *MetadataUpsertBulk) Ignore() *MetadataUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
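Bulk creation is the same idea applied to a slice of single-row builders; the conflict options configured here are attached to the whole batched insert via the BatchCreateSpec above. A minimal sketch with invented keys, same assumptions (client, unique index) as the earlier examples:

    // attachMany writes several metadata pairs for one file in one batched insert.
    func attachMany(ctx context.Context, client *ent.Client, fileID int, pairs map[string]string) error {
        builders := make([]*ent.MetadataCreate, 0, len(pairs))
        for name, value := range pairs {
            builders = append(builders,
                client.Metadata.Create().SetFileID(fileID).SetName(name).SetValue(value))
        }
        return client.Metadata.CreateBulk(builders...).
            OnConflictColumns(metadata.FieldFileID, metadata.FieldName).
            UpdateNewValues().
            Exec(ctx)
    }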
+func (u *MetadataUpsertBulk) DoNothing() *MetadataUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the MetadataCreateBulk.OnConflict +// documentation for more info. +func (u *MetadataUpsertBulk) Update(set func(*MetadataUpsert)) *MetadataUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&MetadataUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *MetadataUpsertBulk) SetUpdatedAt(v time.Time) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateUpdatedAt() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *MetadataUpsertBulk) SetDeletedAt(v time.Time) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateDeletedAt() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *MetadataUpsertBulk) ClearDeletedAt() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *MetadataUpsertBulk) SetName(v string) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateName() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateName() + }) +} + +// SetValue sets the "value" field. +func (u *MetadataUpsertBulk) SetValue(v string) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateValue() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateValue() + }) +} + +// SetFileID sets the "file_id" field. +func (u *MetadataUpsertBulk) SetFileID(v int) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetFileID(v) + }) +} + +// UpdateFileID sets the "file_id" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateFileID() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateFileID() + }) +} + +// SetIsPublic sets the "is_public" field. +func (u *MetadataUpsertBulk) SetIsPublic(v bool) *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.SetIsPublic(v) + }) +} + +// UpdateIsPublic sets the "is_public" field to the value that was provided on create. +func (u *MetadataUpsertBulk) UpdateIsPublic() *MetadataUpsertBulk { + return u.Update(func(s *MetadataUpsert) { + s.UpdateIsPublic() + }) +} + +// Exec executes the query. +func (u *MetadataUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the MetadataCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for MetadataCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *MetadataUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/metadata_delete.go b/ent/metadata_delete.go new file mode 100644 index 00000000..cbccce87 --- /dev/null +++ b/ent/metadata_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// MetadataDelete is the builder for deleting a Metadata entity. +type MetadataDelete struct { + config + hooks []Hook + mutation *MetadataMutation +} + +// Where appends a list predicates to the MetadataDelete builder. +func (md *MetadataDelete) Where(ps ...predicate.Metadata) *MetadataDelete { + md.mutation.Where(ps...) + return md +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (md *MetadataDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (md *MetadataDelete) ExecX(ctx context.Context) int { + n, err := md.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (md *MetadataDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(metadata.Table, sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt)) + if ps := md.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, md.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + md.mutation.done = true + return affected, err +} + +// MetadataDeleteOne is the builder for deleting a single Metadata entity. +type MetadataDeleteOne struct { + md *MetadataDelete +} + +// Where appends a list predicates to the MetadataDelete builder. +func (mdo *MetadataDeleteOne) Where(ps ...predicate.Metadata) *MetadataDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + +// Exec executes the deletion query. +func (mdo *MetadataDeleteOne) Exec(ctx context.Context) error { + n, err := mdo.md.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{metadata.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (mdo *MetadataDeleteOne) ExecX(ctx context.Context) { + if err := mdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/metadata_query.go b/ent/metadata_query.go new file mode 100644 index 00000000..42c97c41 --- /dev/null +++ b/ent/metadata_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// MetadataQuery is the builder for querying Metadata entities. 
+type MetadataQuery struct { + config + ctx *QueryContext + order []metadata.OrderOption + inters []Interceptor + predicates []predicate.Metadata + withFile *FileQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the MetadataQuery builder. +func (mq *MetadataQuery) Where(ps ...predicate.Metadata) *MetadataQuery { + mq.predicates = append(mq.predicates, ps...) + return mq +} + +// Limit the number of records to be returned by this query. +func (mq *MetadataQuery) Limit(limit int) *MetadataQuery { + mq.ctx.Limit = &limit + return mq +} + +// Offset to start from. +func (mq *MetadataQuery) Offset(offset int) *MetadataQuery { + mq.ctx.Offset = &offset + return mq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (mq *MetadataQuery) Unique(unique bool) *MetadataQuery { + mq.ctx.Unique = &unique + return mq +} + +// Order specifies how the records should be ordered. +func (mq *MetadataQuery) Order(o ...metadata.OrderOption) *MetadataQuery { + mq.order = append(mq.order, o...) + return mq +} + +// QueryFile chains the current query on the "file" edge. +func (mq *MetadataQuery) QueryFile() *FileQuery { + query := (&FileClient{config: mq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := mq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(metadata.Table, metadata.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, metadata.FileTable, metadata.FileColumn), + ) + fromU = sqlgraph.SetNeighbors(mq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Metadata entity from the query. +// Returns a *NotFoundError when no Metadata was found. +func (mq *MetadataQuery) First(ctx context.Context) (*Metadata, error) { + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{metadata.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (mq *MetadataQuery) FirstX(ctx context.Context) *Metadata { + node, err := mq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Metadata ID from the query. +// Returns a *NotFoundError when no Metadata ID was found. +func (mq *MetadataQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{metadata.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (mq *MetadataQuery) FirstIDX(ctx context.Context) int { + id, err := mq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Metadata entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Metadata entity is found. +// Returns a *NotFoundError when no Metadata entities are found. 
+func (mq *MetadataQuery) Only(ctx context.Context) (*Metadata, error) { + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{metadata.Label} + default: + return nil, &NotSingularError{metadata.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (mq *MetadataQuery) OnlyX(ctx context.Context) *Metadata { + node, err := mq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Metadata ID in the query. +// Returns a *NotSingularError when more than one Metadata ID is found. +// Returns a *NotFoundError when no entities are found. +func (mq *MetadataQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{metadata.Label} + default: + err = &NotSingularError{metadata.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (mq *MetadataQuery) OnlyIDX(ctx context.Context) int { + id, err := mq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of MetadataSlice. +func (mq *MetadataQuery) All(ctx context.Context) ([]*Metadata, error) { + ctx = setContextOp(ctx, mq.ctx, "All") + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Metadata, *MetadataQuery]() + return withInterceptors[[]*Metadata](ctx, mq, qr, mq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (mq *MetadataQuery) AllX(ctx context.Context) []*Metadata { + nodes, err := mq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Metadata IDs. +func (mq *MetadataQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(metadata.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (mq *MetadataQuery) IDsX(ctx context.Context) []int { + ids, err := mq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (mq *MetadataQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") + if err := mq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, mq, querierCount[*MetadataQuery](), mq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (mq *MetadataQuery) CountX(ctx context.Context) int { + count, err := mq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (mq *MetadataQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. 
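The fetch helpers above pair naturally with WithFile, defined a little further below, which eager-loads the owning File into Edges.File. A minimal sketch, same assumptions as the earlier examples; the "thumbnail" key is invented:

    // metadataWithFile loads a single metadata row together with its owning file.
    func metadataWithFile(ctx context.Context, client *ent.Client, fileID int) (*ent.Metadata, *ent.File, error) {
        m, err := client.Metadata.Query().
            Where(metadata.FileIDEQ(fileID), metadata.NameEQ("thumbnail")).
            WithFile().
            Only(ctx)
        if err != nil {
            return nil, nil, err
        }
        return m, m.Edges.File, nil
    }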
+func (mq *MetadataQuery) ExistX(ctx context.Context) bool { + exist, err := mq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the MetadataQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (mq *MetadataQuery) Clone() *MetadataQuery { + if mq == nil { + return nil + } + return &MetadataQuery{ + config: mq.config, + ctx: mq.ctx.Clone(), + order: append([]metadata.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), + predicates: append([]predicate.Metadata{}, mq.predicates...), + withFile: mq.withFile.Clone(), + // clone intermediate query. + sql: mq.sql.Clone(), + path: mq.path, + } +} + +// WithFile tells the query-builder to eager-load the nodes that are connected to +// the "file" edge. The optional arguments are used to configure the query builder of the edge. +func (mq *MetadataQuery) WithFile(opts ...func(*FileQuery)) *MetadataQuery { + query := (&FileClient{config: mq.config}).Query() + for _, opt := range opts { + opt(query) + } + mq.withFile = query + return mq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Metadata.Query(). +// GroupBy(metadata.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (mq *MetadataQuery) GroupBy(field string, fields ...string) *MetadataGroupBy { + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MetadataGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields + grbuild.label = metadata.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Metadata.Query(). +// Select(metadata.FieldCreatedAt). +// Scan(ctx, &v) +func (mq *MetadataQuery) Select(fields ...string) *MetadataSelect { + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MetadataSelect{MetadataQuery: mq} + sbuild.label = metadata.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MetadataSelect configured with the given aggregations. +func (mq *MetadataQuery) Aggregate(fns ...AggregateFunc) *MetadataSelect { + return mq.Select().Aggregate(fns...) 
+} + +func (mq *MetadataQuery) prepareQuery(ctx context.Context) error { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { + if !metadata.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if mq.path != nil { + prev, err := mq.path(ctx) + if err != nil { + return err + } + mq.sql = prev + } + return nil +} + +func (mq *MetadataQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Metadata, error) { + var ( + nodes = []*Metadata{} + _spec = mq.querySpec() + loadedTypes = [1]bool{ + mq.withFile != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Metadata).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Metadata{config: mq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := mq.withFile; query != nil { + if err := mq.loadFile(ctx, query, nodes, nil, + func(n *Metadata, e *File) { n.Edges.File = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (mq *MetadataQuery) loadFile(ctx context.Context, query *FileQuery, nodes []*Metadata, init func(*Metadata), assign func(*Metadata, *File)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Metadata) + for i := range nodes { + fk := nodes[i].FileID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(file.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (mq *MetadataQuery) sqlCount(ctx context.Context) (int, error) { + _spec := mq.querySpec() + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, mq.driver, _spec) +} + +func (mq *MetadataQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(metadata.Table, metadata.Columns, sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true + } + if fields := mq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metadata.FieldID) + for i := range fields { + if fields[i] != metadata.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if mq.withFile != nil { + _spec.Node.AddColumnOnce(metadata.FieldFileID) + } + } + if ps := mq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := mq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := mq.ctx.Offset; offset 
!= nil { + _spec.Offset = *offset + } + if ps := mq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (mq *MetadataQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(mq.driver.Dialect()) + t1 := builder.Table(metadata.Table) + columns := mq.ctx.Fields + if len(columns) == 0 { + columns = metadata.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if mq.sql != nil { + selector = mq.sql + selector.Select(selector.Columns(columns...)...) + } + if mq.ctx.Unique != nil && *mq.ctx.Unique { + selector.Distinct() + } + for _, p := range mq.predicates { + p(selector) + } + for _, p := range mq.order { + p(selector) + } + if offset := mq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := mq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// MetadataGroupBy is the group-by builder for Metadata entities. +type MetadataGroupBy struct { + selector + build *MetadataQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (mgb *MetadataGroupBy) Aggregate(fns ...AggregateFunc) *MetadataGroupBy { + mgb.fns = append(mgb.fns, fns...) + return mgb +} + +// Scan applies the selector query and scans the result into the given value. +func (mgb *MetadataGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetadataQuery, *MetadataGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) +} + +func (mgb *MetadataGroupBy) sqlScan(ctx context.Context, root *MetadataQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*mgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// MetadataSelect is the builder for selecting fields of Metadata entities. +type MetadataSelect struct { + *MetadataQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MetadataSelect) Aggregate(fns ...AggregateFunc) *MetadataSelect { + ms.fns = append(ms.fns, fns...) + return ms +} + +// Scan applies the selector query and scans the result into the given value. 
+func (ms *MetadataSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") + if err := ms.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MetadataQuery, *MetadataSelect](ctx, ms.MetadataQuery, ms, ms.inters, v) +} + +func (ms *MetadataSelect) sqlScan(ctx context.Context, root *MetadataQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ms.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/metadata_update.go b/ent/metadata_update.go new file mode 100644 index 00000000..3c82cd38 --- /dev/null +++ b/ent/metadata_update.go @@ -0,0 +1,509 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// MetadataUpdate is the builder for updating Metadata entities. +type MetadataUpdate struct { + config + hooks []Hook + mutation *MetadataMutation +} + +// Where appends a list predicates to the MetadataUpdate builder. +func (mu *MetadataUpdate) Where(ps ...predicate.Metadata) *MetadataUpdate { + mu.mutation.Where(ps...) + return mu +} + +// SetUpdatedAt sets the "updated_at" field. +func (mu *MetadataUpdate) SetUpdatedAt(t time.Time) *MetadataUpdate { + mu.mutation.SetUpdatedAt(t) + return mu +} + +// SetDeletedAt sets the "deleted_at" field. +func (mu *MetadataUpdate) SetDeletedAt(t time.Time) *MetadataUpdate { + mu.mutation.SetDeletedAt(t) + return mu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (mu *MetadataUpdate) SetNillableDeletedAt(t *time.Time) *MetadataUpdate { + if t != nil { + mu.SetDeletedAt(*t) + } + return mu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (mu *MetadataUpdate) ClearDeletedAt() *MetadataUpdate { + mu.mutation.ClearDeletedAt() + return mu +} + +// SetName sets the "name" field. +func (mu *MetadataUpdate) SetName(s string) *MetadataUpdate { + mu.mutation.SetName(s) + return mu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (mu *MetadataUpdate) SetNillableName(s *string) *MetadataUpdate { + if s != nil { + mu.SetName(*s) + } + return mu +} + +// SetValue sets the "value" field. +func (mu *MetadataUpdate) SetValue(s string) *MetadataUpdate { + mu.mutation.SetValue(s) + return mu +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (mu *MetadataUpdate) SetNillableValue(s *string) *MetadataUpdate { + if s != nil { + mu.SetValue(*s) + } + return mu +} + +// SetFileID sets the "file_id" field. +func (mu *MetadataUpdate) SetFileID(i int) *MetadataUpdate { + mu.mutation.SetFileID(i) + return mu +} + +// SetNillableFileID sets the "file_id" field if the given value is not nil. 
+func (mu *MetadataUpdate) SetNillableFileID(i *int) *MetadataUpdate { + if i != nil { + mu.SetFileID(*i) + } + return mu +} + +// SetIsPublic sets the "is_public" field. +func (mu *MetadataUpdate) SetIsPublic(b bool) *MetadataUpdate { + mu.mutation.SetIsPublic(b) + return mu +} + +// SetNillableIsPublic sets the "is_public" field if the given value is not nil. +func (mu *MetadataUpdate) SetNillableIsPublic(b *bool) *MetadataUpdate { + if b != nil { + mu.SetIsPublic(*b) + } + return mu +} + +// SetFile sets the "file" edge to the File entity. +func (mu *MetadataUpdate) SetFile(f *File) *MetadataUpdate { + return mu.SetFileID(f.ID) +} + +// Mutation returns the MetadataMutation object of the builder. +func (mu *MetadataUpdate) Mutation() *MetadataMutation { + return mu.mutation +} + +// ClearFile clears the "file" edge to the File entity. +func (mu *MetadataUpdate) ClearFile() *MetadataUpdate { + mu.mutation.ClearFile() + return mu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (mu *MetadataUpdate) Save(ctx context.Context) (int, error) { + if err := mu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mu *MetadataUpdate) SaveX(ctx context.Context) int { + affected, err := mu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (mu *MetadataUpdate) Exec(ctx context.Context) error { + _, err := mu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mu *MetadataUpdate) ExecX(ctx context.Context) { + if err := mu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (mu *MetadataUpdate) defaults() error { + if _, ok := mu.mutation.UpdatedAt(); !ok { + if metadata.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized metadata.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := metadata.UpdateDefaultUpdatedAt() + mu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
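Updates follow the same pattern: predicates select the rows, the setters above describe the change, defaults() stamps updated_at, and check(), whose implementation follows, refuses to clear the required file edge. A minimal sketch; the client.Metadata.Update() accessor is the conventional generated entry point and is assumed here rather than shown in this hunk:

    // publishMetadata flips is_public for every live metadata row of a file.
    func publishMetadata(ctx context.Context, client *ent.Client, fileID int) (int, error) {
        return client.Metadata.Update().
            Where(metadata.FileIDEQ(fileID), metadata.DeletedAtIsNil()).
            SetIsPublic(true).
            Save(ctx)
    }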
+func (mu *MetadataUpdate) check() error { + if _, ok := mu.mutation.FileID(); mu.mutation.FileCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Metadata.file"`) + } + return nil +} + +func (mu *MetadataUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := mu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(metadata.Table, metadata.Columns, sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt)) + if ps := mu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mu.mutation.UpdatedAt(); ok { + _spec.SetField(metadata.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := mu.mutation.DeletedAt(); ok { + _spec.SetField(metadata.FieldDeletedAt, field.TypeTime, value) + } + if mu.mutation.DeletedAtCleared() { + _spec.ClearField(metadata.FieldDeletedAt, field.TypeTime) + } + if value, ok := mu.mutation.Name(); ok { + _spec.SetField(metadata.FieldName, field.TypeString, value) + } + if value, ok := mu.mutation.Value(); ok { + _spec.SetField(metadata.FieldValue, field.TypeString, value) + } + if value, ok := mu.mutation.IsPublic(); ok { + _spec.SetField(metadata.FieldIsPublic, field.TypeBool, value) + } + if mu.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: metadata.FileTable, + Columns: []string{metadata.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mu.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: metadata.FileTable, + Columns: []string{metadata.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, mu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metadata.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + mu.mutation.done = true + return n, nil +} + +// MetadataUpdateOne is the builder for updating a single Metadata entity. +type MetadataUpdateOne struct { + config + fields []string + hooks []Hook + mutation *MetadataMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (muo *MetadataUpdateOne) SetUpdatedAt(t time.Time) *MetadataUpdateOne { + muo.mutation.SetUpdatedAt(t) + return muo +} + +// SetDeletedAt sets the "deleted_at" field. +func (muo *MetadataUpdateOne) SetDeletedAt(t time.Time) *MetadataUpdateOne { + muo.mutation.SetDeletedAt(t) + return muo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (muo *MetadataUpdateOne) SetNillableDeletedAt(t *time.Time) *MetadataUpdateOne { + if t != nil { + muo.SetDeletedAt(*t) + } + return muo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (muo *MetadataUpdateOne) ClearDeletedAt() *MetadataUpdateOne { + muo.mutation.ClearDeletedAt() + return muo +} + +// SetName sets the "name" field. 
+func (muo *MetadataUpdateOne) SetName(s string) *MetadataUpdateOne { + muo.mutation.SetName(s) + return muo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (muo *MetadataUpdateOne) SetNillableName(s *string) *MetadataUpdateOne { + if s != nil { + muo.SetName(*s) + } + return muo +} + +// SetValue sets the "value" field. +func (muo *MetadataUpdateOne) SetValue(s string) *MetadataUpdateOne { + muo.mutation.SetValue(s) + return muo +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (muo *MetadataUpdateOne) SetNillableValue(s *string) *MetadataUpdateOne { + if s != nil { + muo.SetValue(*s) + } + return muo +} + +// SetFileID sets the "file_id" field. +func (muo *MetadataUpdateOne) SetFileID(i int) *MetadataUpdateOne { + muo.mutation.SetFileID(i) + return muo +} + +// SetNillableFileID sets the "file_id" field if the given value is not nil. +func (muo *MetadataUpdateOne) SetNillableFileID(i *int) *MetadataUpdateOne { + if i != nil { + muo.SetFileID(*i) + } + return muo +} + +// SetIsPublic sets the "is_public" field. +func (muo *MetadataUpdateOne) SetIsPublic(b bool) *MetadataUpdateOne { + muo.mutation.SetIsPublic(b) + return muo +} + +// SetNillableIsPublic sets the "is_public" field if the given value is not nil. +func (muo *MetadataUpdateOne) SetNillableIsPublic(b *bool) *MetadataUpdateOne { + if b != nil { + muo.SetIsPublic(*b) + } + return muo +} + +// SetFile sets the "file" edge to the File entity. +func (muo *MetadataUpdateOne) SetFile(f *File) *MetadataUpdateOne { + return muo.SetFileID(f.ID) +} + +// Mutation returns the MetadataMutation object of the builder. +func (muo *MetadataUpdateOne) Mutation() *MetadataMutation { + return muo.mutation +} + +// ClearFile clears the "file" edge to the File entity. +func (muo *MetadataUpdateOne) ClearFile() *MetadataUpdateOne { + muo.mutation.ClearFile() + return muo +} + +// Where appends a list predicates to the MetadataUpdate builder. +func (muo *MetadataUpdateOne) Where(ps ...predicate.Metadata) *MetadataUpdateOne { + muo.mutation.Where(ps...) + return muo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (muo *MetadataUpdateOne) Select(field string, fields ...string) *MetadataUpdateOne { + muo.fields = append([]string{field}, fields...) + return muo +} + +// Save executes the query and returns the updated Metadata entity. +func (muo *MetadataUpdateOne) Save(ctx context.Context) (*Metadata, error) { + if err := muo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (muo *MetadataUpdateOne) SaveX(ctx context.Context) *Metadata { + node, err := muo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (muo *MetadataUpdateOne) Exec(ctx context.Context) error { + _, err := muo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (muo *MetadataUpdateOne) ExecX(ctx context.Context) { + if err := muo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (muo *MetadataUpdateOne) defaults() error { + if _, ok := muo.mutation.UpdatedAt(); !ok { + if metadata.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized metadata.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := metadata.UpdateDefaultUpdatedAt() + muo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (muo *MetadataUpdateOne) check() error { + if _, ok := muo.mutation.FileID(); muo.mutation.FileCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Metadata.file"`) + } + return nil +} + +func (muo *MetadataUpdateOne) sqlSave(ctx context.Context) (_node *Metadata, err error) { + if err := muo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(metadata.Table, metadata.Columns, sqlgraph.NewFieldSpec(metadata.FieldID, field.TypeInt)) + id, ok := muo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Metadata.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := muo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, metadata.FieldID) + for _, f := range fields { + if !metadata.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != metadata.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := muo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := muo.mutation.UpdatedAt(); ok { + _spec.SetField(metadata.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := muo.mutation.DeletedAt(); ok { + _spec.SetField(metadata.FieldDeletedAt, field.TypeTime, value) + } + if muo.mutation.DeletedAtCleared() { + _spec.ClearField(metadata.FieldDeletedAt, field.TypeTime) + } + if value, ok := muo.mutation.Name(); ok { + _spec.SetField(metadata.FieldName, field.TypeString, value) + } + if value, ok := muo.mutation.Value(); ok { + _spec.SetField(metadata.FieldValue, field.TypeString, value) + } + if value, ok := muo.mutation.IsPublic(); ok { + _spec.SetField(metadata.FieldIsPublic, field.TypeBool, value) + } + if muo.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: metadata.FileTable, + Columns: []string{metadata.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := muo.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: metadata.FileTable, + Columns: []string{metadata.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Metadata{config: muo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, muo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{metadata.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + muo.mutation.done = 
true + return _node, nil +} diff --git a/ent/migrate/migrate.go b/ent/migrate/migrate.go new file mode 100644 index 00000000..1956a6bf --- /dev/null +++ b/ent/migrate/migrate.go @@ -0,0 +1,64 @@ +// Code generated by ent, DO NOT EDIT. + +package migrate + +import ( + "context" + "fmt" + "io" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql/schema" +) + +var ( + // WithGlobalUniqueID sets the universal ids options to the migration. + // If this option is enabled, ent migration will allocate a 1<<32 range + // for the ids of each entity (table). + // Note that this option cannot be applied on tables that already exist. + WithGlobalUniqueID = schema.WithGlobalUniqueID + // WithDropColumn sets the drop column option to the migration. + // If this option is enabled, ent migration will drop old columns + // that were used for both fields and edges. This defaults to false. + WithDropColumn = schema.WithDropColumn + // WithDropIndex sets the drop index option to the migration. + // If this option is enabled, ent migration will drop old indexes + // that were defined in the schema. This defaults to false. + // Note that unique constraints are defined using `UNIQUE INDEX`, + // and therefore, it's recommended to enable this option to get more + // flexibility in the schema changes. + WithDropIndex = schema.WithDropIndex + // WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true. + WithForeignKeys = schema.WithForeignKeys +) + +// Schema is the API for creating, migrating and dropping a schema. +type Schema struct { + drv dialect.Driver +} + +// NewSchema creates a new schema client. +func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } + +// Create creates all schema resources. +func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { + return Create(ctx, s, Tables, opts...) +} + +// Create creates all table resources using the given schema driver. +func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error { + migrate, err := schema.NewMigrate(s.drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %w", err) + } + return migrate.Create(ctx, tables...) +} + +// WriteTo writes the schema changes to w instead of running them against the database. +// +// if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { +// log.Fatal(err) +// } +func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { + return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...) +} diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go new file mode 100644 index 00000000..e7c8b68e --- /dev/null +++ b/ent/migrate/schema.go @@ -0,0 +1,486 @@ +// Code generated by ent, DO NOT EDIT. + +package migrate + +import ( + "entgo.io/ent/dialect/sql/schema" + "entgo.io/ent/schema/field" +) + +var ( + // DavAccountsColumns holds the columns for the "dav_accounts" table. 
+ DavAccountsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString}, + {Name: "uri", Type: field.TypeString, Size: 2147483647}, + {Name: "password", Type: field.TypeString}, + {Name: "options", Type: field.TypeBytes}, + {Name: "props", Type: field.TypeJSON, Nullable: true}, + {Name: "owner_id", Type: field.TypeInt}, + } + // DavAccountsTable holds the schema information for the "dav_accounts" table. + DavAccountsTable = &schema.Table{ + Name: "dav_accounts", + Columns: DavAccountsColumns, + PrimaryKey: []*schema.Column{DavAccountsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "dav_accounts_users_dav_accounts", + Columns: []*schema.Column{DavAccountsColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "davaccount_owner_id_password", + Unique: true, + Columns: []*schema.Column{DavAccountsColumns[9], DavAccountsColumns[6]}, + }, + }, + } + // DirectLinksColumns holds the columns for the "direct_links" table. + DirectLinksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString}, + {Name: "downloads", Type: field.TypeInt}, + {Name: "speed", Type: field.TypeInt}, + {Name: "file_id", Type: field.TypeInt}, + } + // DirectLinksTable holds the schema information for the "direct_links" table. + DirectLinksTable = &schema.Table{ + Name: "direct_links", + Columns: DirectLinksColumns, + PrimaryKey: []*schema.Column{DirectLinksColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "direct_links_files_direct_links", + Columns: []*schema.Column{DirectLinksColumns[7]}, + RefColumns: []*schema.Column{FilesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + } + // EntitiesColumns holds the columns for the "entities" table. + EntitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "type", Type: field.TypeInt}, + {Name: "source", Type: field.TypeString, Size: 2147483647}, + {Name: "size", Type: field.TypeInt64}, + {Name: "reference_count", Type: field.TypeInt, Default: 1}, + {Name: "upload_session_id", Type: field.TypeUUID, Nullable: true}, + {Name: "recycle_options", Type: field.TypeJSON, Nullable: true}, + {Name: "storage_policy_entities", Type: field.TypeInt}, + {Name: "created_by", Type: field.TypeInt, Nullable: true}, + } + // EntitiesTable holds the schema information for the "entities" table. 
+ EntitiesTable = &schema.Table{ + Name: "entities", + Columns: EntitiesColumns, + PrimaryKey: []*schema.Column{EntitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "entities_storage_policies_entities", + Columns: []*schema.Column{EntitiesColumns[10]}, + RefColumns: []*schema.Column{StoragePoliciesColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "entities_users_entities", + Columns: []*schema.Column{EntitiesColumns[11]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // FilesColumns holds the columns for the "files" table. + FilesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "type", Type: field.TypeInt}, + {Name: "name", Type: field.TypeString}, + {Name: "size", Type: field.TypeInt64, Default: 0}, + {Name: "primary_entity", Type: field.TypeInt, Nullable: true}, + {Name: "is_symbolic", Type: field.TypeBool, Default: false}, + {Name: "props", Type: field.TypeJSON, Nullable: true}, + {Name: "file_children", Type: field.TypeInt, Nullable: true}, + {Name: "storage_policy_files", Type: field.TypeInt, Nullable: true}, + {Name: "owner_id", Type: field.TypeInt}, + } + // FilesTable holds the schema information for the "files" table. + FilesTable = &schema.Table{ + Name: "files", + Columns: FilesColumns, + PrimaryKey: []*schema.Column{FilesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "files_files_children", + Columns: []*schema.Column{FilesColumns[10]}, + RefColumns: []*schema.Column{FilesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "files_storage_policies_files", + Columns: []*schema.Column{FilesColumns[11]}, + RefColumns: []*schema.Column{StoragePoliciesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "files_users_files", + Columns: []*schema.Column{FilesColumns[12]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "file_file_children_name", + Unique: true, + Columns: []*schema.Column{FilesColumns[10], FilesColumns[5]}, + }, + { + Name: "file_file_children_type_updated_at", + Unique: false, + Columns: []*schema.Column{FilesColumns[10], FilesColumns[4], FilesColumns[2]}, + }, + { + Name: "file_file_children_type_size", + Unique: false, + Columns: []*schema.Column{FilesColumns[10], FilesColumns[4], FilesColumns[6]}, + }, + }, + } + // GroupsColumns holds the columns for the "groups" table. 
+ GroupsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString}, + {Name: "max_storage", Type: field.TypeInt64, Nullable: true}, + {Name: "speed_limit", Type: field.TypeInt, Nullable: true}, + {Name: "permissions", Type: field.TypeBytes}, + {Name: "settings", Type: field.TypeJSON, Nullable: true}, + {Name: "storage_policy_id", Type: field.TypeInt, Nullable: true}, + } + // GroupsTable holds the schema information for the "groups" table. + GroupsTable = &schema.Table{ + Name: "groups", + Columns: GroupsColumns, + PrimaryKey: []*schema.Column{GroupsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "groups_storage_policies_groups", + Columns: []*schema.Column{GroupsColumns[9]}, + RefColumns: []*schema.Column{StoragePoliciesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // MetadataColumns holds the columns for the "metadata" table. + MetadataColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString}, + {Name: "value", Type: field.TypeString, Size: 2147483647}, + {Name: "is_public", Type: field.TypeBool, Default: false}, + {Name: "file_id", Type: field.TypeInt}, + } + // MetadataTable holds the schema information for the "metadata" table. + MetadataTable = &schema.Table{ + Name: "metadata", + Columns: MetadataColumns, + PrimaryKey: []*schema.Column{MetadataColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "metadata_files_metadata", + Columns: []*schema.Column{MetadataColumns[7]}, + RefColumns: []*schema.Column{FilesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "metadata_file_id_name", + Unique: true, + Columns: []*schema.Column{MetadataColumns[7], MetadataColumns[4]}, + }, + }, + } + // NodesColumns holds the columns for the "nodes" table. + NodesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "suspended"}}, + {Name: "name", Type: field.TypeString}, + {Name: "type", Type: field.TypeEnum, Enums: []string{"master", "slave"}}, + {Name: "server", Type: field.TypeString, Nullable: true}, + {Name: "slave_key", Type: field.TypeString, Nullable: true}, + {Name: "capabilities", Type: field.TypeBytes}, + {Name: "settings", Type: field.TypeJSON, Nullable: true}, + {Name: "weight", Type: field.TypeInt, Default: 0}, + } + // NodesTable holds the schema information for the "nodes" table. 
+ NodesTable = &schema.Table{ + Name: "nodes", + Columns: NodesColumns, + PrimaryKey: []*schema.Column{NodesColumns[0]}, + } + // PasskeysColumns holds the columns for the "passkeys" table. + PasskeysColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "credential_id", Type: field.TypeString}, + {Name: "name", Type: field.TypeString}, + {Name: "credential", Type: field.TypeJSON}, + {Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "user_id", Type: field.TypeInt}, + } + // PasskeysTable holds the schema information for the "passkeys" table. + PasskeysTable = &schema.Table{ + Name: "passkeys", + Columns: PasskeysColumns, + PrimaryKey: []*schema.Column{PasskeysColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "passkeys_users_passkey", + Columns: []*schema.Column{PasskeysColumns[8]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "passkey_user_id_credential_id", + Unique: true, + Columns: []*schema.Column{PasskeysColumns[8], PasskeysColumns[4]}, + }, + }, + } + // SettingsColumns holds the columns for the "settings" table. + SettingsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString, Unique: true}, + {Name: "value", Type: field.TypeString, Nullable: true, Size: 2147483647}, + } + // SettingsTable holds the schema information for the "settings" table. + SettingsTable = &schema.Table{ + Name: "settings", + Columns: SettingsColumns, + PrimaryKey: []*schema.Column{SettingsColumns[0]}, + } + // SharesColumns holds the columns for the "shares" table. + SharesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "password", Type: field.TypeString, Nullable: true}, + {Name: "views", Type: field.TypeInt, Default: 0}, + {Name: "downloads", Type: field.TypeInt, Default: 0}, + {Name: "expires", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "remain_downloads", Type: field.TypeInt, Nullable: true}, + {Name: "file_shares", Type: field.TypeInt, Nullable: true}, + {Name: "user_shares", Type: field.TypeInt, Nullable: true}, + } + // SharesTable holds the schema information for the "shares" table. 
+ SharesTable = &schema.Table{ + Name: "shares", + Columns: SharesColumns, + PrimaryKey: []*schema.Column{SharesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "shares_files_shares", + Columns: []*schema.Column{SharesColumns[9]}, + RefColumns: []*schema.Column{FilesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "shares_users_shares", + Columns: []*schema.Column{SharesColumns[10]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // StoragePoliciesColumns holds the columns for the "storage_policies" table. + StoragePoliciesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "name", Type: field.TypeString}, + {Name: "type", Type: field.TypeString}, + {Name: "server", Type: field.TypeString, Nullable: true}, + {Name: "bucket_name", Type: field.TypeString, Nullable: true}, + {Name: "is_private", Type: field.TypeBool, Nullable: true}, + {Name: "access_key", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "secret_key", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "max_size", Type: field.TypeInt64, Nullable: true}, + {Name: "dir_name_rule", Type: field.TypeString, Nullable: true}, + {Name: "file_name_rule", Type: field.TypeString, Nullable: true}, + {Name: "settings", Type: field.TypeJSON, Nullable: true}, + {Name: "node_id", Type: field.TypeInt, Nullable: true}, + } + // StoragePoliciesTable holds the schema information for the "storage_policies" table. + StoragePoliciesTable = &schema.Table{ + Name: "storage_policies", + Columns: StoragePoliciesColumns, + PrimaryKey: []*schema.Column{StoragePoliciesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "storage_policies_nodes_storage_policy", + Columns: []*schema.Column{StoragePoliciesColumns[15]}, + RefColumns: []*schema.Column{NodesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // TasksColumns holds the columns for the "tasks" table. + TasksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "type", Type: field.TypeString}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"queued", "processing", "suspending", "error", "canceled", "completed"}, Default: "queued"}, + {Name: "public_state", Type: field.TypeJSON}, + {Name: "private_state", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "correlation_id", Type: field.TypeUUID, Nullable: true}, + {Name: "user_tasks", Type: field.TypeInt, Nullable: true}, + } + // TasksTable holds the schema information for the "tasks" table. 
+ TasksTable = &schema.Table{ + Name: "tasks", + Columns: TasksColumns, + PrimaryKey: []*schema.Column{TasksColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "tasks_users_tasks", + Columns: []*schema.Column{TasksColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // UsersColumns holds the columns for the "users" table. + UsersColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime"}}, + {Name: "email", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "nick", Type: field.TypeString, Size: 100}, + {Name: "password", Type: field.TypeString, Nullable: true}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "manual_banned", "sys_banned"}, Default: "active"}, + {Name: "storage", Type: field.TypeInt64, Default: 0}, + {Name: "two_factor_secret", Type: field.TypeString, Nullable: true}, + {Name: "avatar", Type: field.TypeString, Nullable: true}, + {Name: "settings", Type: field.TypeJSON, Nullable: true}, + {Name: "group_users", Type: field.TypeInt}, + {Name: "storage_policy_users", Type: field.TypeInt, Nullable: true}, + } + // UsersTable holds the schema information for the "users" table. + UsersTable = &schema.Table{ + Name: "users", + Columns: UsersColumns, + PrimaryKey: []*schema.Column{UsersColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "users_groups_users", + Columns: []*schema.Column{UsersColumns[12]}, + RefColumns: []*schema.Column{GroupsColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "users_storage_policies_users", + Columns: []*schema.Column{UsersColumns[13]}, + RefColumns: []*schema.Column{StoragePoliciesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + } + // FileEntitiesColumns holds the columns for the "file_entities" table. + FileEntitiesColumns = []*schema.Column{ + {Name: "file_id", Type: field.TypeInt}, + {Name: "entity_id", Type: field.TypeInt}, + } + // FileEntitiesTable holds the schema information for the "file_entities" table. + FileEntitiesTable = &schema.Table{ + Name: "file_entities", + Columns: FileEntitiesColumns, + PrimaryKey: []*schema.Column{FileEntitiesColumns[0], FileEntitiesColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "file_entities_file_id", + Columns: []*schema.Column{FileEntitiesColumns[0]}, + RefColumns: []*schema.Column{FilesColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "file_entities_entity_id", + Columns: []*schema.Column{FileEntitiesColumns[1]}, + RefColumns: []*schema.Column{EntitiesColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } + // Tables holds all the tables in the schema. 
+ Tables = []*schema.Table{ + DavAccountsTable, + DirectLinksTable, + EntitiesTable, + FilesTable, + GroupsTable, + MetadataTable, + NodesTable, + PasskeysTable, + SettingsTable, + SharesTable, + StoragePoliciesTable, + TasksTable, + UsersTable, + FileEntitiesTable, + } +) + +func init() { + DavAccountsTable.ForeignKeys[0].RefTable = UsersTable + DirectLinksTable.ForeignKeys[0].RefTable = FilesTable + EntitiesTable.ForeignKeys[0].RefTable = StoragePoliciesTable + EntitiesTable.ForeignKeys[1].RefTable = UsersTable + FilesTable.ForeignKeys[0].RefTable = FilesTable + FilesTable.ForeignKeys[1].RefTable = StoragePoliciesTable + FilesTable.ForeignKeys[2].RefTable = UsersTable + GroupsTable.ForeignKeys[0].RefTable = StoragePoliciesTable + MetadataTable.ForeignKeys[0].RefTable = FilesTable + PasskeysTable.ForeignKeys[0].RefTable = UsersTable + SharesTable.ForeignKeys[0].RefTable = FilesTable + SharesTable.ForeignKeys[1].RefTable = UsersTable + StoragePoliciesTable.ForeignKeys[0].RefTable = NodesTable + TasksTable.ForeignKeys[0].RefTable = UsersTable + UsersTable.ForeignKeys[0].RefTable = GroupsTable + UsersTable.ForeignKeys[1].RefTable = StoragePoliciesTable + FileEntitiesTable.ForeignKeys[0].RefTable = FilesTable + FileEntitiesTable.ForeignKeys[1].RefTable = EntitiesTable +} diff --git a/ent/mutation.go b/ent/mutation.go new file mode 100644 index 00000000..0cb5ac9e --- /dev/null +++ b/ent/mutation.go @@ -0,0 +1,14201 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" +) + +const ( + // Operation types. + OpCreate = ent.OpCreate + OpDelete = ent.OpDelete + OpDeleteOne = ent.OpDeleteOne + OpUpdate = ent.OpUpdate + OpUpdateOne = ent.OpUpdateOne + + // Node types. + TypeDavAccount = "DavAccount" + TypeDirectLink = "DirectLink" + TypeEntity = "Entity" + TypeFile = "File" + TypeGroup = "Group" + TypeMetadata = "Metadata" + TypeNode = "Node" + TypePasskey = "Passkey" + TypeSetting = "Setting" + TypeShare = "Share" + TypeStoragePolicy = "StoragePolicy" + TypeTask = "Task" + TypeUser = "User" +) + +// DavAccountMutation represents an operation that mutates the DavAccount nodes in the graph. 
+type DavAccountMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + uri *string + password *string + options **boolset.BooleanSet + props **types.DavAccountProps + clearedFields map[string]struct{} + owner *int + clearedowner bool + done bool + oldValue func(context.Context) (*DavAccount, error) + predicates []predicate.DavAccount +} + +var _ ent.Mutation = (*DavAccountMutation)(nil) + +// davaccountOption allows management of the mutation configuration using functional options. +type davaccountOption func(*DavAccountMutation) + +// newDavAccountMutation creates new mutation for the DavAccount entity. +func newDavAccountMutation(c config, op Op, opts ...davaccountOption) *DavAccountMutation { + m := &DavAccountMutation{ + config: c, + op: op, + typ: TypeDavAccount, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withDavAccountID sets the ID field of the mutation. +func withDavAccountID(id int) davaccountOption { + return func(m *DavAccountMutation) { + var ( + err error + once sync.Once + value *DavAccount + ) + m.oldValue = func(ctx context.Context) (*DavAccount, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().DavAccount.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withDavAccount sets the old DavAccount of the mutation. +func withDavAccount(node *DavAccount) davaccountOption { + return func(m *DavAccountMutation) { + m.oldValue = func(context.Context) (*DavAccount, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m DavAccountMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m DavAccountMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *DavAccountMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *DavAccountMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().DavAccount.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *DavAccountMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. 
+func (m *DavAccountMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *DavAccountMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *DavAccountMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *DavAccountMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *DavAccountMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *DavAccountMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *DavAccountMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. 
+func (m *DavAccountMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[davaccount.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *DavAccountMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[davaccount.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *DavAccountMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, davaccount.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *DavAccountMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *DavAccountMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *DavAccountMutation) ResetName() { + m.name = nil +} + +// SetURI sets the "uri" field. +func (m *DavAccountMutation) SetURI(s string) { + m.uri = &s +} + +// URI returns the value of the "uri" field in the mutation. +func (m *DavAccountMutation) URI() (r string, exists bool) { + v := m.uri + if v == nil { + return + } + return *v, true +} + +// OldURI returns the old "uri" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldURI(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldURI is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldURI requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldURI: %w", err) + } + return oldValue.URI, nil +} + +// ResetURI resets all changes to the "uri" field. +func (m *DavAccountMutation) ResetURI() { + m.uri = nil +} + +// SetPassword sets the "password" field. +func (m *DavAccountMutation) SetPassword(s string) { + m.password = &s +} + +// Password returns the value of the "password" field in the mutation. +func (m *DavAccountMutation) Password() (r string, exists bool) { + v := m.password + if v == nil { + return + } + return *v, true +} + +// OldPassword returns the old "password" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *DavAccountMutation) OldPassword(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassword is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassword requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassword: %w", err) + } + return oldValue.Password, nil +} + +// ResetPassword resets all changes to the "password" field. +func (m *DavAccountMutation) ResetPassword() { + m.password = nil +} + +// SetOptions sets the "options" field. +func (m *DavAccountMutation) SetOptions(bs *boolset.BooleanSet) { + m.options = &bs +} + +// Options returns the value of the "options" field in the mutation. +func (m *DavAccountMutation) Options() (r *boolset.BooleanSet, exists bool) { + v := m.options + if v == nil { + return + } + return *v, true +} + +// OldOptions returns the old "options" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldOptions(ctx context.Context) (v *boolset.BooleanSet, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOptions is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOptions requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOptions: %w", err) + } + return oldValue.Options, nil +} + +// ResetOptions resets all changes to the "options" field. +func (m *DavAccountMutation) ResetOptions() { + m.options = nil +} + +// SetProps sets the "props" field. +func (m *DavAccountMutation) SetProps(tap *types.DavAccountProps) { + m.props = &tap +} + +// Props returns the value of the "props" field in the mutation. +func (m *DavAccountMutation) Props() (r *types.DavAccountProps, exists bool) { + v := m.props + if v == nil { + return + } + return *v, true +} + +// OldProps returns the old "props" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldProps(ctx context.Context) (v *types.DavAccountProps, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProps is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProps requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProps: %w", err) + } + return oldValue.Props, nil +} + +// ClearProps clears the value of the "props" field. +func (m *DavAccountMutation) ClearProps() { + m.props = nil + m.clearedFields[davaccount.FieldProps] = struct{}{} +} + +// PropsCleared returns if the "props" field was cleared in this mutation. +func (m *DavAccountMutation) PropsCleared() bool { + _, ok := m.clearedFields[davaccount.FieldProps] + return ok +} + +// ResetProps resets all changes to the "props" field. +func (m *DavAccountMutation) ResetProps() { + m.props = nil + delete(m.clearedFields, davaccount.FieldProps) +} + +// SetOwnerID sets the "owner_id" field. 
+func (m *DavAccountMutation) SetOwnerID(i int) { + m.owner = &i +} + +// OwnerID returns the value of the "owner_id" field in the mutation. +func (m *DavAccountMutation) OwnerID() (r int, exists bool) { + v := m.owner + if v == nil { + return + } + return *v, true +} + +// OldOwnerID returns the old "owner_id" field's value of the DavAccount entity. +// If the DavAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DavAccountMutation) OldOwnerID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOwnerID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOwnerID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOwnerID: %w", err) + } + return oldValue.OwnerID, nil +} + +// ResetOwnerID resets all changes to the "owner_id" field. +func (m *DavAccountMutation) ResetOwnerID() { + m.owner = nil +} + +// ClearOwner clears the "owner" edge to the User entity. +func (m *DavAccountMutation) ClearOwner() { + m.clearedowner = true + m.clearedFields[davaccount.FieldOwnerID] = struct{}{} +} + +// OwnerCleared reports if the "owner" edge to the User entity was cleared. +func (m *DavAccountMutation) OwnerCleared() bool { + return m.clearedowner +} + +// OwnerIDs returns the "owner" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// OwnerID instead. It exists only for internal usage by the builders. +func (m *DavAccountMutation) OwnerIDs() (ids []int) { + if id := m.owner; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetOwner resets all changes to the "owner" edge. +func (m *DavAccountMutation) ResetOwner() { + m.owner = nil + m.clearedowner = false +} + +// Where appends a list predicates to the DavAccountMutation builder. +func (m *DavAccountMutation) Where(ps ...predicate.DavAccount) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the DavAccountMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DavAccountMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.DavAccount, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *DavAccountMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *DavAccountMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (DavAccount). +func (m *DavAccountMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *DavAccountMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, davaccount.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, davaccount.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, davaccount.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, davaccount.FieldName) + } + if m.uri != nil { + fields = append(fields, davaccount.FieldURI) + } + if m.password != nil { + fields = append(fields, davaccount.FieldPassword) + } + if m.options != nil { + fields = append(fields, davaccount.FieldOptions) + } + if m.props != nil { + fields = append(fields, davaccount.FieldProps) + } + if m.owner != nil { + fields = append(fields, davaccount.FieldOwnerID) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *DavAccountMutation) Field(name string) (ent.Value, bool) { + switch name { + case davaccount.FieldCreatedAt: + return m.CreatedAt() + case davaccount.FieldUpdatedAt: + return m.UpdatedAt() + case davaccount.FieldDeletedAt: + return m.DeletedAt() + case davaccount.FieldName: + return m.Name() + case davaccount.FieldURI: + return m.URI() + case davaccount.FieldPassword: + return m.Password() + case davaccount.FieldOptions: + return m.Options() + case davaccount.FieldProps: + return m.Props() + case davaccount.FieldOwnerID: + return m.OwnerID() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *DavAccountMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case davaccount.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case davaccount.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case davaccount.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case davaccount.FieldName: + return m.OldName(ctx) + case davaccount.FieldURI: + return m.OldURI(ctx) + case davaccount.FieldPassword: + return m.OldPassword(ctx) + case davaccount.FieldOptions: + return m.OldOptions(ctx) + case davaccount.FieldProps: + return m.OldProps(ctx) + case davaccount.FieldOwnerID: + return m.OldOwnerID(ctx) + } + return nil, fmt.Errorf("unknown DavAccount field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *DavAccountMutation) SetField(name string, value ent.Value) error { + switch name { + case davaccount.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case davaccount.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case davaccount.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case davaccount.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case davaccount.FieldURI: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetURI(v) + return nil + case davaccount.FieldPassword: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassword(v) + return nil + case davaccount.FieldOptions: + v, ok := value.(*boolset.BooleanSet) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOptions(v) + return nil + case davaccount.FieldProps: + v, ok := value.(*types.DavAccountProps) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProps(v) + return nil + case davaccount.FieldOwnerID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOwnerID(v) + return nil + } + return fmt.Errorf("unknown DavAccount field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *DavAccountMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *DavAccountMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *DavAccountMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown DavAccount numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *DavAccountMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(davaccount.FieldDeletedAt) { + fields = append(fields, davaccount.FieldDeletedAt) + } + if m.FieldCleared(davaccount.FieldProps) { + fields = append(fields, davaccount.FieldProps) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *DavAccountMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. 
+func (m *DavAccountMutation) ClearField(name string) error { + switch name { + case davaccount.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case davaccount.FieldProps: + m.ClearProps() + return nil + } + return fmt.Errorf("unknown DavAccount nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *DavAccountMutation) ResetField(name string) error { + switch name { + case davaccount.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case davaccount.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case davaccount.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case davaccount.FieldName: + m.ResetName() + return nil + case davaccount.FieldURI: + m.ResetURI() + return nil + case davaccount.FieldPassword: + m.ResetPassword() + return nil + case davaccount.FieldOptions: + m.ResetOptions() + return nil + case davaccount.FieldProps: + m.ResetProps() + return nil + case davaccount.FieldOwnerID: + m.ResetOwnerID() + return nil + } + return fmt.Errorf("unknown DavAccount field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *DavAccountMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.owner != nil { + edges = append(edges, davaccount.EdgeOwner) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *DavAccountMutation) AddedIDs(name string) []ent.Value { + switch name { + case davaccount.EdgeOwner: + if id := m.owner; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *DavAccountMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *DavAccountMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *DavAccountMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedowner { + edges = append(edges, davaccount.EdgeOwner) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *DavAccountMutation) EdgeCleared(name string) bool { + switch name { + case davaccount.EdgeOwner: + return m.clearedowner + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *DavAccountMutation) ClearEdge(name string) error { + switch name { + case davaccount.EdgeOwner: + m.ClearOwner() + return nil + } + return fmt.Errorf("unknown DavAccount unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *DavAccountMutation) ResetEdge(name string) error { + switch name { + case davaccount.EdgeOwner: + m.ResetOwner() + return nil + } + return fmt.Errorf("unknown DavAccount edge %s", name) +} + +// DirectLinkMutation represents an operation that mutates the DirectLink nodes in the graph. 
+type DirectLinkMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + downloads *int + adddownloads *int + speed *int + addspeed *int + clearedFields map[string]struct{} + file *int + clearedfile bool + done bool + oldValue func(context.Context) (*DirectLink, error) + predicates []predicate.DirectLink +} + +var _ ent.Mutation = (*DirectLinkMutation)(nil) + +// directlinkOption allows management of the mutation configuration using functional options. +type directlinkOption func(*DirectLinkMutation) + +// newDirectLinkMutation creates new mutation for the DirectLink entity. +func newDirectLinkMutation(c config, op Op, opts ...directlinkOption) *DirectLinkMutation { + m := &DirectLinkMutation{ + config: c, + op: op, + typ: TypeDirectLink, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withDirectLinkID sets the ID field of the mutation. +func withDirectLinkID(id int) directlinkOption { + return func(m *DirectLinkMutation) { + var ( + err error + once sync.Once + value *DirectLink + ) + m.oldValue = func(ctx context.Context) (*DirectLink, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().DirectLink.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withDirectLink sets the old DirectLink of the mutation. +func withDirectLink(node *DirectLink) directlinkOption { + return func(m *DirectLinkMutation) { + m.oldValue = func(context.Context) (*DirectLink, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m DirectLinkMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m DirectLinkMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *DirectLinkMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *DirectLinkMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().DirectLink.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *DirectLinkMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. 
+func (m *DirectLinkMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *DirectLinkMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *DirectLinkMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *DirectLinkMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *DirectLinkMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *DirectLinkMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *DirectLinkMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. 
+func (m *DirectLinkMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[directlink.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *DirectLinkMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[directlink.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *DirectLinkMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, directlink.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *DirectLinkMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *DirectLinkMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *DirectLinkMutation) ResetName() { + m.name = nil +} + +// SetDownloads sets the "downloads" field. +func (m *DirectLinkMutation) SetDownloads(i int) { + m.downloads = &i + m.adddownloads = nil +} + +// Downloads returns the value of the "downloads" field in the mutation. +func (m *DirectLinkMutation) Downloads() (r int, exists bool) { + v := m.downloads + if v == nil { + return + } + return *v, true +} + +// OldDownloads returns the old "downloads" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldDownloads(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDownloads is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDownloads requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDownloads: %w", err) + } + return oldValue.Downloads, nil +} + +// AddDownloads adds i to the "downloads" field. +func (m *DirectLinkMutation) AddDownloads(i int) { + if m.adddownloads != nil { + *m.adddownloads += i + } else { + m.adddownloads = &i + } +} + +// AddedDownloads returns the value that was added to the "downloads" field in this mutation. +func (m *DirectLinkMutation) AddedDownloads() (r int, exists bool) { + v := m.adddownloads + if v == nil { + return + } + return *v, true +} + +// ResetDownloads resets all changes to the "downloads" field. +func (m *DirectLinkMutation) ResetDownloads() { + m.downloads = nil + m.adddownloads = nil +} + +// SetFileID sets the "file_id" field. 
+func (m *DirectLinkMutation) SetFileID(i int) { + m.file = &i +} + +// FileID returns the value of the "file_id" field in the mutation. +func (m *DirectLinkMutation) FileID() (r int, exists bool) { + v := m.file + if v == nil { + return + } + return *v, true +} + +// OldFileID returns the old "file_id" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldFileID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFileID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFileID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFileID: %w", err) + } + return oldValue.FileID, nil +} + +// ResetFileID resets all changes to the "file_id" field. +func (m *DirectLinkMutation) ResetFileID() { + m.file = nil +} + +// SetSpeed sets the "speed" field. +func (m *DirectLinkMutation) SetSpeed(i int) { + m.speed = &i + m.addspeed = nil +} + +// Speed returns the value of the "speed" field in the mutation. +func (m *DirectLinkMutation) Speed() (r int, exists bool) { + v := m.speed + if v == nil { + return + } + return *v, true +} + +// OldSpeed returns the old "speed" field's value of the DirectLink entity. +// If the DirectLink object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DirectLinkMutation) OldSpeed(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSpeed is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSpeed requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSpeed: %w", err) + } + return oldValue.Speed, nil +} + +// AddSpeed adds i to the "speed" field. +func (m *DirectLinkMutation) AddSpeed(i int) { + if m.addspeed != nil { + *m.addspeed += i + } else { + m.addspeed = &i + } +} + +// AddedSpeed returns the value that was added to the "speed" field in this mutation. +func (m *DirectLinkMutation) AddedSpeed() (r int, exists bool) { + v := m.addspeed + if v == nil { + return + } + return *v, true +} + +// ResetSpeed resets all changes to the "speed" field. +func (m *DirectLinkMutation) ResetSpeed() { + m.speed = nil + m.addspeed = nil +} + +// ClearFile clears the "file" edge to the File entity. +func (m *DirectLinkMutation) ClearFile() { + m.clearedfile = true + m.clearedFields[directlink.FieldFileID] = struct{}{} +} + +// FileCleared reports if the "file" edge to the File entity was cleared. +func (m *DirectLinkMutation) FileCleared() bool { + return m.clearedfile +} + +// FileIDs returns the "file" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// FileID instead. It exists only for internal usage by the builders. +func (m *DirectLinkMutation) FileIDs() (ids []int) { + if id := m.file; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetFile resets all changes to the "file" edge. 
+func (m *DirectLinkMutation) ResetFile() { + m.file = nil + m.clearedfile = false +} + +// Where appends a list predicates to the DirectLinkMutation builder. +func (m *DirectLinkMutation) Where(ps ...predicate.DirectLink) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the DirectLinkMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DirectLinkMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.DirectLink, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *DirectLinkMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *DirectLinkMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (DirectLink). +func (m *DirectLinkMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *DirectLinkMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, directlink.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, directlink.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, directlink.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, directlink.FieldName) + } + if m.downloads != nil { + fields = append(fields, directlink.FieldDownloads) + } + if m.file != nil { + fields = append(fields, directlink.FieldFileID) + } + if m.speed != nil { + fields = append(fields, directlink.FieldSpeed) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *DirectLinkMutation) Field(name string) (ent.Value, bool) { + switch name { + case directlink.FieldCreatedAt: + return m.CreatedAt() + case directlink.FieldUpdatedAt: + return m.UpdatedAt() + case directlink.FieldDeletedAt: + return m.DeletedAt() + case directlink.FieldName: + return m.Name() + case directlink.FieldDownloads: + return m.Downloads() + case directlink.FieldFileID: + return m.FileID() + case directlink.FieldSpeed: + return m.Speed() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *DirectLinkMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case directlink.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case directlink.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case directlink.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case directlink.FieldName: + return m.OldName(ctx) + case directlink.FieldDownloads: + return m.OldDownloads(ctx) + case directlink.FieldFileID: + return m.OldFileID(ctx) + case directlink.FieldSpeed: + return m.OldSpeed(ctx) + } + return nil, fmt.Errorf("unknown DirectLink field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *DirectLinkMutation) SetField(name string, value ent.Value) error { + switch name { + case directlink.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case directlink.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case directlink.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case directlink.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case directlink.FieldDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDownloads(v) + return nil + case directlink.FieldFileID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFileID(v) + return nil + case directlink.FieldSpeed: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSpeed(v) + return nil + } + return fmt.Errorf("unknown DirectLink field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *DirectLinkMutation) AddedFields() []string { + var fields []string + if m.adddownloads != nil { + fields = append(fields, directlink.FieldDownloads) + } + if m.addspeed != nil { + fields = append(fields, directlink.FieldSpeed) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *DirectLinkMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case directlink.FieldDownloads: + return m.AddedDownloads() + case directlink.FieldSpeed: + return m.AddedSpeed() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *DirectLinkMutation) AddField(name string, value ent.Value) error { + switch name { + case directlink.FieldDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDownloads(v) + return nil + case directlink.FieldSpeed: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSpeed(v) + return nil + } + return fmt.Errorf("unknown DirectLink numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *DirectLinkMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(directlink.FieldDeletedAt) { + fields = append(fields, directlink.FieldDeletedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *DirectLinkMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. 
+func (m *DirectLinkMutation) ClearField(name string) error { + switch name { + case directlink.FieldDeletedAt: + m.ClearDeletedAt() + return nil + } + return fmt.Errorf("unknown DirectLink nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *DirectLinkMutation) ResetField(name string) error { + switch name { + case directlink.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case directlink.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case directlink.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case directlink.FieldName: + m.ResetName() + return nil + case directlink.FieldDownloads: + m.ResetDownloads() + return nil + case directlink.FieldFileID: + m.ResetFileID() + return nil + case directlink.FieldSpeed: + m.ResetSpeed() + return nil + } + return fmt.Errorf("unknown DirectLink field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *DirectLinkMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.file != nil { + edges = append(edges, directlink.EdgeFile) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *DirectLinkMutation) AddedIDs(name string) []ent.Value { + switch name { + case directlink.EdgeFile: + if id := m.file; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *DirectLinkMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *DirectLinkMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *DirectLinkMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedfile { + edges = append(edges, directlink.EdgeFile) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *DirectLinkMutation) EdgeCleared(name string) bool { + switch name { + case directlink.EdgeFile: + return m.clearedfile + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *DirectLinkMutation) ClearEdge(name string) error { + switch name { + case directlink.EdgeFile: + m.ClearFile() + return nil + } + return fmt.Errorf("unknown DirectLink unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *DirectLinkMutation) ResetEdge(name string) error { + switch name { + case directlink.EdgeFile: + m.ResetFile() + return nil + } + return fmt.Errorf("unknown DirectLink edge %s", name) +} + +// EntityMutation represents an operation that mutates the Entity nodes in the graph. 
+type EntityMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + _type *int + add_type *int + source *string + size *int64 + addsize *int64 + reference_count *int + addreference_count *int + upload_session_id *uuid.UUID + recycle_options **types.EntityRecycleOption + clearedFields map[string]struct{} + file map[int]struct{} + removedfile map[int]struct{} + clearedfile bool + user *int + cleareduser bool + storage_policy *int + clearedstorage_policy bool + done bool + oldValue func(context.Context) (*Entity, error) + predicates []predicate.Entity +} + +var _ ent.Mutation = (*EntityMutation)(nil) + +// entityOption allows management of the mutation configuration using functional options. +type entityOption func(*EntityMutation) + +// newEntityMutation creates new mutation for the Entity entity. +func newEntityMutation(c config, op Op, opts ...entityOption) *EntityMutation { + m := &EntityMutation{ + config: c, + op: op, + typ: TypeEntity, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withEntityID sets the ID field of the mutation. +func withEntityID(id int) entityOption { + return func(m *EntityMutation) { + var ( + err error + once sync.Once + value *Entity + ) + m.oldValue = func(ctx context.Context) (*Entity, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Entity.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withEntity sets the old Entity of the mutation. +func withEntity(node *Entity) entityOption { + return func(m *EntityMutation) { + m.oldValue = func(context.Context) (*Entity, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m EntityMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m EntityMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *EntityMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *EntityMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Entity.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. 
+func (m *EntityMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *EntityMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *EntityMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *EntityMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *EntityMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *EntityMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *EntityMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *EntityMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *EntityMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *EntityMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[entity.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *EntityMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[entity.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *EntityMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, entity.FieldDeletedAt) +} + +// SetType sets the "type" field. +func (m *EntityMutation) SetType(i int) { + m._type = &i + m.add_type = nil +} + +// GetType returns the value of the "type" field in the mutation. +func (m *EntityMutation) GetType() (r int, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldType(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// AddType adds i to the "type" field. +func (m *EntityMutation) AddType(i int) { + if m.add_type != nil { + *m.add_type += i + } else { + m.add_type = &i + } +} + +// AddedType returns the value that was added to the "type" field in this mutation. +func (m *EntityMutation) AddedType() (r int, exists bool) { + v := m.add_type + if v == nil { + return + } + return *v, true +} + +// ResetType resets all changes to the "type" field. +func (m *EntityMutation) ResetType() { + m._type = nil + m.add_type = nil +} + +// SetSource sets the "source" field. +func (m *EntityMutation) SetSource(s string) { + m.source = &s +} + +// Source returns the value of the "source" field in the mutation. +func (m *EntityMutation) Source() (r string, exists bool) { + v := m.source + if v == nil { + return + } + return *v, true +} + +// OldSource returns the old "source" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *EntityMutation) OldSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSource: %w", err) + } + return oldValue.Source, nil +} + +// ResetSource resets all changes to the "source" field. +func (m *EntityMutation) ResetSource() { + m.source = nil +} + +// SetSize sets the "size" field. +func (m *EntityMutation) SetSize(i int64) { + m.size = &i + m.addsize = nil +} + +// Size returns the value of the "size" field in the mutation. +func (m *EntityMutation) Size() (r int64, exists bool) { + v := m.size + if v == nil { + return + } + return *v, true +} + +// OldSize returns the old "size" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldSize(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSize: %w", err) + } + return oldValue.Size, nil +} + +// AddSize adds i to the "size" field. +func (m *EntityMutation) AddSize(i int64) { + if m.addsize != nil { + *m.addsize += i + } else { + m.addsize = &i + } +} + +// AddedSize returns the value that was added to the "size" field in this mutation. +func (m *EntityMutation) AddedSize() (r int64, exists bool) { + v := m.addsize + if v == nil { + return + } + return *v, true +} + +// ResetSize resets all changes to the "size" field. +func (m *EntityMutation) ResetSize() { + m.size = nil + m.addsize = nil +} + +// SetReferenceCount sets the "reference_count" field. +func (m *EntityMutation) SetReferenceCount(i int) { + m.reference_count = &i + m.addreference_count = nil +} + +// ReferenceCount returns the value of the "reference_count" field in the mutation. +func (m *EntityMutation) ReferenceCount() (r int, exists bool) { + v := m.reference_count + if v == nil { + return + } + return *v, true +} + +// OldReferenceCount returns the old "reference_count" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldReferenceCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldReferenceCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldReferenceCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldReferenceCount: %w", err) + } + return oldValue.ReferenceCount, nil +} + +// AddReferenceCount adds i to the "reference_count" field. 
+func (m *EntityMutation) AddReferenceCount(i int) { + if m.addreference_count != nil { + *m.addreference_count += i + } else { + m.addreference_count = &i + } +} + +// AddedReferenceCount returns the value that was added to the "reference_count" field in this mutation. +func (m *EntityMutation) AddedReferenceCount() (r int, exists bool) { + v := m.addreference_count + if v == nil { + return + } + return *v, true +} + +// ResetReferenceCount resets all changes to the "reference_count" field. +func (m *EntityMutation) ResetReferenceCount() { + m.reference_count = nil + m.addreference_count = nil +} + +// SetStoragePolicyEntities sets the "storage_policy_entities" field. +func (m *EntityMutation) SetStoragePolicyEntities(i int) { + m.storage_policy = &i +} + +// StoragePolicyEntities returns the value of the "storage_policy_entities" field in the mutation. +func (m *EntityMutation) StoragePolicyEntities() (r int, exists bool) { + v := m.storage_policy + if v == nil { + return + } + return *v, true +} + +// OldStoragePolicyEntities returns the old "storage_policy_entities" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldStoragePolicyEntities(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStoragePolicyEntities is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStoragePolicyEntities requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStoragePolicyEntities: %w", err) + } + return oldValue.StoragePolicyEntities, nil +} + +// ResetStoragePolicyEntities resets all changes to the "storage_policy_entities" field. +func (m *EntityMutation) ResetStoragePolicyEntities() { + m.storage_policy = nil +} + +// SetCreatedBy sets the "created_by" field. +func (m *EntityMutation) SetCreatedBy(i int) { + m.user = &i +} + +// CreatedBy returns the value of the "created_by" field in the mutation. +func (m *EntityMutation) CreatedBy() (r int, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldCreatedBy returns the old "created_by" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldCreatedBy(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedBy requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err) + } + return oldValue.CreatedBy, nil +} + +// ClearCreatedBy clears the value of the "created_by" field. +func (m *EntityMutation) ClearCreatedBy() { + m.user = nil + m.clearedFields[entity.FieldCreatedBy] = struct{}{} +} + +// CreatedByCleared returns if the "created_by" field was cleared in this mutation. 
+func (m *EntityMutation) CreatedByCleared() bool { + _, ok := m.clearedFields[entity.FieldCreatedBy] + return ok +} + +// ResetCreatedBy resets all changes to the "created_by" field. +func (m *EntityMutation) ResetCreatedBy() { + m.user = nil + delete(m.clearedFields, entity.FieldCreatedBy) +} + +// SetUploadSessionID sets the "upload_session_id" field. +func (m *EntityMutation) SetUploadSessionID(u uuid.UUID) { + m.upload_session_id = &u +} + +// UploadSessionID returns the value of the "upload_session_id" field in the mutation. +func (m *EntityMutation) UploadSessionID() (r uuid.UUID, exists bool) { + v := m.upload_session_id + if v == nil { + return + } + return *v, true +} + +// OldUploadSessionID returns the old "upload_session_id" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldUploadSessionID(ctx context.Context) (v *uuid.UUID, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUploadSessionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUploadSessionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUploadSessionID: %w", err) + } + return oldValue.UploadSessionID, nil +} + +// ClearUploadSessionID clears the value of the "upload_session_id" field. +func (m *EntityMutation) ClearUploadSessionID() { + m.upload_session_id = nil + m.clearedFields[entity.FieldUploadSessionID] = struct{}{} +} + +// UploadSessionIDCleared returns if the "upload_session_id" field was cleared in this mutation. +func (m *EntityMutation) UploadSessionIDCleared() bool { + _, ok := m.clearedFields[entity.FieldUploadSessionID] + return ok +} + +// ResetUploadSessionID resets all changes to the "upload_session_id" field. +func (m *EntityMutation) ResetUploadSessionID() { + m.upload_session_id = nil + delete(m.clearedFields, entity.FieldUploadSessionID) +} + +// SetRecycleOptions sets the "recycle_options" field. +func (m *EntityMutation) SetRecycleOptions(tro *types.EntityRecycleOption) { + m.recycle_options = &tro +} + +// RecycleOptions returns the value of the "recycle_options" field in the mutation. +func (m *EntityMutation) RecycleOptions() (r *types.EntityRecycleOption, exists bool) { + v := m.recycle_options + if v == nil { + return + } + return *v, true +} + +// OldRecycleOptions returns the old "recycle_options" field's value of the Entity entity. +// If the Entity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *EntityMutation) OldRecycleOptions(ctx context.Context) (v *types.EntityRecycleOption, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRecycleOptions is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRecycleOptions requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRecycleOptions: %w", err) + } + return oldValue.RecycleOptions, nil +} + +// ClearRecycleOptions clears the value of the "recycle_options" field. 
+func (m *EntityMutation) ClearRecycleOptions() { + m.recycle_options = nil + m.clearedFields[entity.FieldRecycleOptions] = struct{}{} +} + +// RecycleOptionsCleared returns if the "recycle_options" field was cleared in this mutation. +func (m *EntityMutation) RecycleOptionsCleared() bool { + _, ok := m.clearedFields[entity.FieldRecycleOptions] + return ok +} + +// ResetRecycleOptions resets all changes to the "recycle_options" field. +func (m *EntityMutation) ResetRecycleOptions() { + m.recycle_options = nil + delete(m.clearedFields, entity.FieldRecycleOptions) +} + +// AddFileIDs adds the "file" edge to the File entity by ids. +func (m *EntityMutation) AddFileIDs(ids ...int) { + if m.file == nil { + m.file = make(map[int]struct{}) + } + for i := range ids { + m.file[ids[i]] = struct{}{} + } +} + +// ClearFile clears the "file" edge to the File entity. +func (m *EntityMutation) ClearFile() { + m.clearedfile = true +} + +// FileCleared reports if the "file" edge to the File entity was cleared. +func (m *EntityMutation) FileCleared() bool { + return m.clearedfile +} + +// RemoveFileIDs removes the "file" edge to the File entity by IDs. +func (m *EntityMutation) RemoveFileIDs(ids ...int) { + if m.removedfile == nil { + m.removedfile = make(map[int]struct{}) + } + for i := range ids { + delete(m.file, ids[i]) + m.removedfile[ids[i]] = struct{}{} + } +} + +// RemovedFile returns the removed IDs of the "file" edge to the File entity. +func (m *EntityMutation) RemovedFileIDs() (ids []int) { + for id := range m.removedfile { + ids = append(ids, id) + } + return +} + +// FileIDs returns the "file" edge IDs in the mutation. +func (m *EntityMutation) FileIDs() (ids []int) { + for id := range m.file { + ids = append(ids, id) + } + return +} + +// ResetFile resets all changes to the "file" edge. +func (m *EntityMutation) ResetFile() { + m.file = nil + m.clearedfile = false + m.removedfile = nil +} + +// SetUserID sets the "user" edge to the User entity by id. +func (m *EntityMutation) SetUserID(id int) { + m.user = &id +} + +// ClearUser clears the "user" edge to the User entity. +func (m *EntityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[entity.FieldCreatedBy] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *EntityMutation) UserCleared() bool { + return m.CreatedByCleared() || m.cleareduser +} + +// UserID returns the "user" edge ID in the mutation. +func (m *EntityMutation) UserID() (id int, exists bool) { + if m.user != nil { + return *m.user, true + } + return +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *EntityMutation) UserIDs() (ids []int) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *EntityMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// SetStoragePolicyID sets the "storage_policy" edge to the StoragePolicy entity by id. +func (m *EntityMutation) SetStoragePolicyID(id int) { + m.storage_policy = &id +} + +// ClearStoragePolicy clears the "storage_policy" edge to the StoragePolicy entity. 
+func (m *EntityMutation) ClearStoragePolicy() { + m.clearedstorage_policy = true + m.clearedFields[entity.FieldStoragePolicyEntities] = struct{}{} +} + +// StoragePolicyCleared reports if the "storage_policy" edge to the StoragePolicy entity was cleared. +func (m *EntityMutation) StoragePolicyCleared() bool { + return m.clearedstorage_policy +} + +// StoragePolicyID returns the "storage_policy" edge ID in the mutation. +func (m *EntityMutation) StoragePolicyID() (id int, exists bool) { + if m.storage_policy != nil { + return *m.storage_policy, true + } + return +} + +// StoragePolicyIDs returns the "storage_policy" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// StoragePolicyID instead. It exists only for internal usage by the builders. +func (m *EntityMutation) StoragePolicyIDs() (ids []int) { + if id := m.storage_policy; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetStoragePolicy resets all changes to the "storage_policy" edge. +func (m *EntityMutation) ResetStoragePolicy() { + m.storage_policy = nil + m.clearedstorage_policy = false +} + +// Where appends a list predicates to the EntityMutation builder. +func (m *EntityMutation) Where(ps ...predicate.Entity) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the EntityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *EntityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Entity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *EntityMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *EntityMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Entity). +func (m *EntityMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *EntityMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, entity.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, entity.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, entity.FieldDeletedAt) + } + if m._type != nil { + fields = append(fields, entity.FieldType) + } + if m.source != nil { + fields = append(fields, entity.FieldSource) + } + if m.size != nil { + fields = append(fields, entity.FieldSize) + } + if m.reference_count != nil { + fields = append(fields, entity.FieldReferenceCount) + } + if m.storage_policy != nil { + fields = append(fields, entity.FieldStoragePolicyEntities) + } + if m.user != nil { + fields = append(fields, entity.FieldCreatedBy) + } + if m.upload_session_id != nil { + fields = append(fields, entity.FieldUploadSessionID) + } + if m.recycle_options != nil { + fields = append(fields, entity.FieldRecycleOptions) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *EntityMutation) Field(name string) (ent.Value, bool) { + switch name { + case entity.FieldCreatedAt: + return m.CreatedAt() + case entity.FieldUpdatedAt: + return m.UpdatedAt() + case entity.FieldDeletedAt: + return m.DeletedAt() + case entity.FieldType: + return m.GetType() + case entity.FieldSource: + return m.Source() + case entity.FieldSize: + return m.Size() + case entity.FieldReferenceCount: + return m.ReferenceCount() + case entity.FieldStoragePolicyEntities: + return m.StoragePolicyEntities() + case entity.FieldCreatedBy: + return m.CreatedBy() + case entity.FieldUploadSessionID: + return m.UploadSessionID() + case entity.FieldRecycleOptions: + return m.RecycleOptions() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *EntityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case entity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case entity.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case entity.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case entity.FieldType: + return m.OldType(ctx) + case entity.FieldSource: + return m.OldSource(ctx) + case entity.FieldSize: + return m.OldSize(ctx) + case entity.FieldReferenceCount: + return m.OldReferenceCount(ctx) + case entity.FieldStoragePolicyEntities: + return m.OldStoragePolicyEntities(ctx) + case entity.FieldCreatedBy: + return m.OldCreatedBy(ctx) + case entity.FieldUploadSessionID: + return m.OldUploadSessionID(ctx) + case entity.FieldRecycleOptions: + return m.OldRecycleOptions(ctx) + } + return nil, fmt.Errorf("unknown Entity field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *EntityMutation) SetField(name string, value ent.Value) error { + switch name { + case entity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case entity.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case entity.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case entity.FieldType: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case entity.FieldSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSource(v) + return nil + case entity.FieldSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSize(v) + return nil + case entity.FieldReferenceCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetReferenceCount(v) + return nil + case entity.FieldStoragePolicyEntities: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStoragePolicyEntities(v) + return nil + case entity.FieldCreatedBy: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedBy(v) + return nil + case entity.FieldUploadSessionID: + v, ok := value.(uuid.UUID) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUploadSessionID(v) + return nil + case entity.FieldRecycleOptions: + v, ok := value.(*types.EntityRecycleOption) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRecycleOptions(v) + return nil + } + return fmt.Errorf("unknown Entity field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *EntityMutation) AddedFields() []string { + var fields []string + if m.add_type != nil { + fields = append(fields, entity.FieldType) + } + if m.addsize != nil { + fields = append(fields, entity.FieldSize) + } + if m.addreference_count != nil { + fields = append(fields, entity.FieldReferenceCount) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *EntityMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case entity.FieldType: + return m.AddedType() + case entity.FieldSize: + return m.AddedSize() + case entity.FieldReferenceCount: + return m.AddedReferenceCount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *EntityMutation) AddField(name string, value ent.Value) error { + switch name { + case entity.FieldType: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddType(v) + return nil + case entity.FieldSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSize(v) + return nil + case entity.FieldReferenceCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddReferenceCount(v) + return nil + } + return fmt.Errorf("unknown Entity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *EntityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(entity.FieldDeletedAt) { + fields = append(fields, entity.FieldDeletedAt) + } + if m.FieldCleared(entity.FieldCreatedBy) { + fields = append(fields, entity.FieldCreatedBy) + } + if m.FieldCleared(entity.FieldUploadSessionID) { + fields = append(fields, entity.FieldUploadSessionID) + } + if m.FieldCleared(entity.FieldRecycleOptions) { + fields = append(fields, entity.FieldRecycleOptions) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *EntityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *EntityMutation) ClearField(name string) error { + switch name { + case entity.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case entity.FieldCreatedBy: + m.ClearCreatedBy() + return nil + case entity.FieldUploadSessionID: + m.ClearUploadSessionID() + return nil + case entity.FieldRecycleOptions: + m.ClearRecycleOptions() + return nil + } + return fmt.Errorf("unknown Entity nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *EntityMutation) ResetField(name string) error { + switch name { + case entity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case entity.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case entity.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case entity.FieldType: + m.ResetType() + return nil + case entity.FieldSource: + m.ResetSource() + return nil + case entity.FieldSize: + m.ResetSize() + return nil + case entity.FieldReferenceCount: + m.ResetReferenceCount() + return nil + case entity.FieldStoragePolicyEntities: + m.ResetStoragePolicyEntities() + return nil + case entity.FieldCreatedBy: + m.ResetCreatedBy() + return nil + case entity.FieldUploadSessionID: + m.ResetUploadSessionID() + return nil + case entity.FieldRecycleOptions: + m.ResetRecycleOptions() + return nil + } + return fmt.Errorf("unknown Entity field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. 
+func (m *EntityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.file != nil { + edges = append(edges, entity.EdgeFile) + } + if m.user != nil { + edges = append(edges, entity.EdgeUser) + } + if m.storage_policy != nil { + edges = append(edges, entity.EdgeStoragePolicy) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *EntityMutation) AddedIDs(name string) []ent.Value { + switch name { + case entity.EdgeFile: + ids := make([]ent.Value, 0, len(m.file)) + for id := range m.file { + ids = append(ids, id) + } + return ids + case entity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case entity.EdgeStoragePolicy: + if id := m.storage_policy; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *EntityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedfile != nil { + edges = append(edges, entity.EdgeFile) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *EntityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case entity.EdgeFile: + ids := make([]ent.Value, 0, len(m.removedfile)) + for id := range m.removedfile { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *EntityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.clearedfile { + edges = append(edges, entity.EdgeFile) + } + if m.cleareduser { + edges = append(edges, entity.EdgeUser) + } + if m.clearedstorage_policy { + edges = append(edges, entity.EdgeStoragePolicy) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *EntityMutation) EdgeCleared(name string) bool { + switch name { + case entity.EdgeFile: + return m.clearedfile + case entity.EdgeUser: + return m.cleareduser + case entity.EdgeStoragePolicy: + return m.clearedstorage_policy + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *EntityMutation) ClearEdge(name string) error { + switch name { + case entity.EdgeUser: + m.ClearUser() + return nil + case entity.EdgeStoragePolicy: + m.ClearStoragePolicy() + return nil + } + return fmt.Errorf("unknown Entity unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *EntityMutation) ResetEdge(name string) error { + switch name { + case entity.EdgeFile: + m.ResetFile() + return nil + case entity.EdgeUser: + m.ResetUser() + return nil + case entity.EdgeStoragePolicy: + m.ResetStoragePolicy() + return nil + } + return fmt.Errorf("unknown Entity edge %s", name) +} + +// FileMutation represents an operation that mutates the File nodes in the graph. 
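+// Mutations of this kind are produced by the generated File builders (for
+// example client.File.Create() or client.File.UpdateOne(...)) and are normally
+// inspected or adjusted from hooks rather than constructed by hand.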
+type FileMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + _type *int + add_type *int + name *string + size *int64 + addsize *int64 + primary_entity *int + addprimary_entity *int + is_symbolic *bool + props **types.FileProps + clearedFields map[string]struct{} + owner *int + clearedowner bool + storage_policies *int + clearedstorage_policies bool + parent *int + clearedparent bool + children map[int]struct{} + removedchildren map[int]struct{} + clearedchildren bool + metadata map[int]struct{} + removedmetadata map[int]struct{} + clearedmetadata bool + entities map[int]struct{} + removedentities map[int]struct{} + clearedentities bool + shares map[int]struct{} + removedshares map[int]struct{} + clearedshares bool + direct_links map[int]struct{} + removeddirect_links map[int]struct{} + cleareddirect_links bool + done bool + oldValue func(context.Context) (*File, error) + predicates []predicate.File +} + +var _ ent.Mutation = (*FileMutation)(nil) + +// fileOption allows management of the mutation configuration using functional options. +type fileOption func(*FileMutation) + +// newFileMutation creates new mutation for the File entity. +func newFileMutation(c config, op Op, opts ...fileOption) *FileMutation { + m := &FileMutation{ + config: c, + op: op, + typ: TypeFile, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withFileID sets the ID field of the mutation. +func withFileID(id int) fileOption { + return func(m *FileMutation) { + var ( + err error + once sync.Once + value *File + ) + m.oldValue = func(ctx context.Context) (*File, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().File.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withFile sets the old File of the mutation. +func withFile(node *File) fileOption { + return func(m *FileMutation) { + m.oldValue = func(context.Context) (*File, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m FileMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m FileMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *FileMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *FileMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().File.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *FileMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *FileMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *FileMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *FileMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *FileMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *FileMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *FileMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *FileMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
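+// Together with the other Old* accessors, this lets an UpdateOne hook compare
+// the staged value against what is currently stored. A rough hand-written
+// sketch, given an ent.Mutation m received in a hook (illustrative only):
+//
+//	if fm, ok := m.(*FileMutation); ok && fm.Op().Is(OpUpdateOne) {
+//		was, _ := fm.OldDeletedAt(ctx)  // value currently stored in the database
+//		now, _ := fm.DeletedAt()        // value staged by this mutation
+//		_ = was == nil && !now.IsZero() // e.g. the row is being soft-deleted
+//	}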
+func (m *FileMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *FileMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[file.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *FileMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[file.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *FileMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, file.FieldDeletedAt) +} + +// SetType sets the "type" field. +func (m *FileMutation) SetType(i int) { + m._type = &i + m.add_type = nil +} + +// GetType returns the value of the "type" field in the mutation. +func (m *FileMutation) GetType() (r int, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldType(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// AddType adds i to the "type" field. +func (m *FileMutation) AddType(i int) { + if m.add_type != nil { + *m.add_type += i + } else { + m.add_type = &i + } +} + +// AddedType returns the value that was added to the "type" field in this mutation. +func (m *FileMutation) AddedType() (r int, exists bool) { + v := m.add_type + if v == nil { + return + } + return *v, true +} + +// ResetType resets all changes to the "type" field. +func (m *FileMutation) ResetType() { + m._type = nil + m.add_type = nil +} + +// SetName sets the "name" field. +func (m *FileMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *FileMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *FileMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *FileMutation) ResetName() { + m.name = nil +} + +// SetOwnerID sets the "owner_id" field. +func (m *FileMutation) SetOwnerID(i int) { + m.owner = &i +} + +// OwnerID returns the value of the "owner_id" field in the mutation. +func (m *FileMutation) OwnerID() (r int, exists bool) { + v := m.owner + if v == nil { + return + } + return *v, true +} + +// OldOwnerID returns the old "owner_id" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldOwnerID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOwnerID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOwnerID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOwnerID: %w", err) + } + return oldValue.OwnerID, nil +} + +// ResetOwnerID resets all changes to the "owner_id" field. +func (m *FileMutation) ResetOwnerID() { + m.owner = nil +} + +// SetSize sets the "size" field. +func (m *FileMutation) SetSize(i int64) { + m.size = &i + m.addsize = nil +} + +// Size returns the value of the "size" field in the mutation. +func (m *FileMutation) Size() (r int64, exists bool) { + v := m.size + if v == nil { + return + } + return *v, true +} + +// OldSize returns the old "size" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldSize(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSize: %w", err) + } + return oldValue.Size, nil +} + +// AddSize adds i to the "size" field. +func (m *FileMutation) AddSize(i int64) { + if m.addsize != nil { + *m.addsize += i + } else { + m.addsize = &i + } +} + +// AddedSize returns the value that was added to the "size" field in this mutation. +func (m *FileMutation) AddedSize() (r int64, exists bool) { + v := m.addsize + if v == nil { + return + } + return *v, true +} + +// ResetSize resets all changes to the "size" field. +func (m *FileMutation) ResetSize() { + m.size = nil + m.addsize = nil +} + +// SetPrimaryEntity sets the "primary_entity" field. +func (m *FileMutation) SetPrimaryEntity(i int) { + m.primary_entity = &i + m.addprimary_entity = nil +} + +// PrimaryEntity returns the value of the "primary_entity" field in the mutation. 
+func (m *FileMutation) PrimaryEntity() (r int, exists bool) { + v := m.primary_entity + if v == nil { + return + } + return *v, true +} + +// OldPrimaryEntity returns the old "primary_entity" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldPrimaryEntity(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrimaryEntity is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrimaryEntity requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrimaryEntity: %w", err) + } + return oldValue.PrimaryEntity, nil +} + +// AddPrimaryEntity adds i to the "primary_entity" field. +func (m *FileMutation) AddPrimaryEntity(i int) { + if m.addprimary_entity != nil { + *m.addprimary_entity += i + } else { + m.addprimary_entity = &i + } +} + +// AddedPrimaryEntity returns the value that was added to the "primary_entity" field in this mutation. +func (m *FileMutation) AddedPrimaryEntity() (r int, exists bool) { + v := m.addprimary_entity + if v == nil { + return + } + return *v, true +} + +// ClearPrimaryEntity clears the value of the "primary_entity" field. +func (m *FileMutation) ClearPrimaryEntity() { + m.primary_entity = nil + m.addprimary_entity = nil + m.clearedFields[file.FieldPrimaryEntity] = struct{}{} +} + +// PrimaryEntityCleared returns if the "primary_entity" field was cleared in this mutation. +func (m *FileMutation) PrimaryEntityCleared() bool { + _, ok := m.clearedFields[file.FieldPrimaryEntity] + return ok +} + +// ResetPrimaryEntity resets all changes to the "primary_entity" field. +func (m *FileMutation) ResetPrimaryEntity() { + m.primary_entity = nil + m.addprimary_entity = nil + delete(m.clearedFields, file.FieldPrimaryEntity) +} + +// SetFileChildren sets the "file_children" field. +func (m *FileMutation) SetFileChildren(i int) { + m.parent = &i +} + +// FileChildren returns the value of the "file_children" field in the mutation. +func (m *FileMutation) FileChildren() (r int, exists bool) { + v := m.parent + if v == nil { + return + } + return *v, true +} + +// OldFileChildren returns the old "file_children" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldFileChildren(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFileChildren is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFileChildren requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFileChildren: %w", err) + } + return oldValue.FileChildren, nil +} + +// ClearFileChildren clears the value of the "file_children" field. +func (m *FileMutation) ClearFileChildren() { + m.parent = nil + m.clearedFields[file.FieldFileChildren] = struct{}{} +} + +// FileChildrenCleared returns if the "file_children" field was cleared in this mutation. 
+func (m *FileMutation) FileChildrenCleared() bool { + _, ok := m.clearedFields[file.FieldFileChildren] + return ok +} + +// ResetFileChildren resets all changes to the "file_children" field. +func (m *FileMutation) ResetFileChildren() { + m.parent = nil + delete(m.clearedFields, file.FieldFileChildren) +} + +// SetIsSymbolic sets the "is_symbolic" field. +func (m *FileMutation) SetIsSymbolic(b bool) { + m.is_symbolic = &b +} + +// IsSymbolic returns the value of the "is_symbolic" field in the mutation. +func (m *FileMutation) IsSymbolic() (r bool, exists bool) { + v := m.is_symbolic + if v == nil { + return + } + return *v, true +} + +// OldIsSymbolic returns the old "is_symbolic" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldIsSymbolic(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsSymbolic is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsSymbolic requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsSymbolic: %w", err) + } + return oldValue.IsSymbolic, nil +} + +// ResetIsSymbolic resets all changes to the "is_symbolic" field. +func (m *FileMutation) ResetIsSymbolic() { + m.is_symbolic = nil +} + +// SetProps sets the "props" field. +func (m *FileMutation) SetProps(tp *types.FileProps) { + m.props = &tp +} + +// Props returns the value of the "props" field in the mutation. +func (m *FileMutation) Props() (r *types.FileProps, exists bool) { + v := m.props + if v == nil { + return + } + return *v, true +} + +// OldProps returns the old "props" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldProps(ctx context.Context) (v *types.FileProps, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProps is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProps requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProps: %w", err) + } + return oldValue.Props, nil +} + +// ClearProps clears the value of the "props" field. +func (m *FileMutation) ClearProps() { + m.props = nil + m.clearedFields[file.FieldProps] = struct{}{} +} + +// PropsCleared returns if the "props" field was cleared in this mutation. +func (m *FileMutation) PropsCleared() bool { + _, ok := m.clearedFields[file.FieldProps] + return ok +} + +// ResetProps resets all changes to the "props" field. +func (m *FileMutation) ResetProps() { + m.props = nil + delete(m.clearedFields, file.FieldProps) +} + +// SetStoragePolicyFiles sets the "storage_policy_files" field. +func (m *FileMutation) SetStoragePolicyFiles(i int) { + m.storage_policies = &i +} + +// StoragePolicyFiles returns the value of the "storage_policy_files" field in the mutation. 
+func (m *FileMutation) StoragePolicyFiles() (r int, exists bool) { + v := m.storage_policies + if v == nil { + return + } + return *v, true +} + +// OldStoragePolicyFiles returns the old "storage_policy_files" field's value of the File entity. +// If the File object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *FileMutation) OldStoragePolicyFiles(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStoragePolicyFiles is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStoragePolicyFiles requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStoragePolicyFiles: %w", err) + } + return oldValue.StoragePolicyFiles, nil +} + +// ClearStoragePolicyFiles clears the value of the "storage_policy_files" field. +func (m *FileMutation) ClearStoragePolicyFiles() { + m.storage_policies = nil + m.clearedFields[file.FieldStoragePolicyFiles] = struct{}{} +} + +// StoragePolicyFilesCleared returns if the "storage_policy_files" field was cleared in this mutation. +func (m *FileMutation) StoragePolicyFilesCleared() bool { + _, ok := m.clearedFields[file.FieldStoragePolicyFiles] + return ok +} + +// ResetStoragePolicyFiles resets all changes to the "storage_policy_files" field. +func (m *FileMutation) ResetStoragePolicyFiles() { + m.storage_policies = nil + delete(m.clearedFields, file.FieldStoragePolicyFiles) +} + +// ClearOwner clears the "owner" edge to the User entity. +func (m *FileMutation) ClearOwner() { + m.clearedowner = true + m.clearedFields[file.FieldOwnerID] = struct{}{} +} + +// OwnerCleared reports if the "owner" edge to the User entity was cleared. +func (m *FileMutation) OwnerCleared() bool { + return m.clearedowner +} + +// OwnerIDs returns the "owner" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// OwnerID instead. It exists only for internal usage by the builders. +func (m *FileMutation) OwnerIDs() (ids []int) { + if id := m.owner; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetOwner resets all changes to the "owner" edge. +func (m *FileMutation) ResetOwner() { + m.owner = nil + m.clearedowner = false +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by id. +func (m *FileMutation) SetStoragePoliciesID(id int) { + m.storage_policies = &id +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. +func (m *FileMutation) ClearStoragePolicies() { + m.clearedstorage_policies = true + m.clearedFields[file.FieldStoragePolicyFiles] = struct{}{} +} + +// StoragePoliciesCleared reports if the "storage_policies" edge to the StoragePolicy entity was cleared. +func (m *FileMutation) StoragePoliciesCleared() bool { + return m.StoragePolicyFilesCleared() || m.clearedstorage_policies +} + +// StoragePoliciesID returns the "storage_policies" edge ID in the mutation. +func (m *FileMutation) StoragePoliciesID() (id int, exists bool) { + if m.storage_policies != nil { + return *m.storage_policies, true + } + return +} + +// StoragePoliciesIDs returns the "storage_policies" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// StoragePoliciesID instead. 
It exists only for internal usage by the builders. +func (m *FileMutation) StoragePoliciesIDs() (ids []int) { + if id := m.storage_policies; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetStoragePolicies resets all changes to the "storage_policies" edge. +func (m *FileMutation) ResetStoragePolicies() { + m.storage_policies = nil + m.clearedstorage_policies = false +} + +// SetParentID sets the "parent" edge to the File entity by id. +func (m *FileMutation) SetParentID(id int) { + m.parent = &id +} + +// ClearParent clears the "parent" edge to the File entity. +func (m *FileMutation) ClearParent() { + m.clearedparent = true + m.clearedFields[file.FieldFileChildren] = struct{}{} +} + +// ParentCleared reports if the "parent" edge to the File entity was cleared. +func (m *FileMutation) ParentCleared() bool { + return m.FileChildrenCleared() || m.clearedparent +} + +// ParentID returns the "parent" edge ID in the mutation. +func (m *FileMutation) ParentID() (id int, exists bool) { + if m.parent != nil { + return *m.parent, true + } + return +} + +// ParentIDs returns the "parent" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ParentID instead. It exists only for internal usage by the builders. +func (m *FileMutation) ParentIDs() (ids []int) { + if id := m.parent; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetParent resets all changes to the "parent" edge. +func (m *FileMutation) ResetParent() { + m.parent = nil + m.clearedparent = false +} + +// AddChildIDs adds the "children" edge to the File entity by ids. +func (m *FileMutation) AddChildIDs(ids ...int) { + if m.children == nil { + m.children = make(map[int]struct{}) + } + for i := range ids { + m.children[ids[i]] = struct{}{} + } +} + +// ClearChildren clears the "children" edge to the File entity. +func (m *FileMutation) ClearChildren() { + m.clearedchildren = true +} + +// ChildrenCleared reports if the "children" edge to the File entity was cleared. +func (m *FileMutation) ChildrenCleared() bool { + return m.clearedchildren +} + +// RemoveChildIDs removes the "children" edge to the File entity by IDs. +func (m *FileMutation) RemoveChildIDs(ids ...int) { + if m.removedchildren == nil { + m.removedchildren = make(map[int]struct{}) + } + for i := range ids { + delete(m.children, ids[i]) + m.removedchildren[ids[i]] = struct{}{} + } +} + +// RemovedChildren returns the removed IDs of the "children" edge to the File entity. +func (m *FileMutation) RemovedChildrenIDs() (ids []int) { + for id := range m.removedchildren { + ids = append(ids, id) + } + return +} + +// ChildrenIDs returns the "children" edge IDs in the mutation. +func (m *FileMutation) ChildrenIDs() (ids []int) { + for id := range m.children { + ids = append(ids, id) + } + return +} + +// ResetChildren resets all changes to the "children" edge. +func (m *FileMutation) ResetChildren() { + m.children = nil + m.clearedchildren = false + m.removedchildren = nil +} + +// AddMetadatumIDs adds the "metadata" edge to the Metadata entity by ids. +func (m *FileMutation) AddMetadatumIDs(ids ...int) { + if m.metadata == nil { + m.metadata = make(map[int]struct{}) + } + for i := range ids { + m.metadata[ids[i]] = struct{}{} + } +} + +// ClearMetadata clears the "metadata" edge to the Metadata entity. +func (m *FileMutation) ClearMetadata() { + m.clearedmetadata = true +} + +// MetadataCleared reports if the "metadata" edge to the Metadata entity was cleared. 
+func (m *FileMutation) MetadataCleared() bool { + return m.clearedmetadata +} + +// RemoveMetadatumIDs removes the "metadata" edge to the Metadata entity by IDs. +func (m *FileMutation) RemoveMetadatumIDs(ids ...int) { + if m.removedmetadata == nil { + m.removedmetadata = make(map[int]struct{}) + } + for i := range ids { + delete(m.metadata, ids[i]) + m.removedmetadata[ids[i]] = struct{}{} + } +} + +// RemovedMetadata returns the removed IDs of the "metadata" edge to the Metadata entity. +func (m *FileMutation) RemovedMetadataIDs() (ids []int) { + for id := range m.removedmetadata { + ids = append(ids, id) + } + return +} + +// MetadataIDs returns the "metadata" edge IDs in the mutation. +func (m *FileMutation) MetadataIDs() (ids []int) { + for id := range m.metadata { + ids = append(ids, id) + } + return +} + +// ResetMetadata resets all changes to the "metadata" edge. +func (m *FileMutation) ResetMetadata() { + m.metadata = nil + m.clearedmetadata = false + m.removedmetadata = nil +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by ids. +func (m *FileMutation) AddEntityIDs(ids ...int) { + if m.entities == nil { + m.entities = make(map[int]struct{}) + } + for i := range ids { + m.entities[ids[i]] = struct{}{} + } +} + +// ClearEntities clears the "entities" edge to the Entity entity. +func (m *FileMutation) ClearEntities() { + m.clearedentities = true +} + +// EntitiesCleared reports if the "entities" edge to the Entity entity was cleared. +func (m *FileMutation) EntitiesCleared() bool { + return m.clearedentities +} + +// RemoveEntityIDs removes the "entities" edge to the Entity entity by IDs. +func (m *FileMutation) RemoveEntityIDs(ids ...int) { + if m.removedentities == nil { + m.removedentities = make(map[int]struct{}) + } + for i := range ids { + delete(m.entities, ids[i]) + m.removedentities[ids[i]] = struct{}{} + } +} + +// RemovedEntities returns the removed IDs of the "entities" edge to the Entity entity. +func (m *FileMutation) RemovedEntitiesIDs() (ids []int) { + for id := range m.removedentities { + ids = append(ids, id) + } + return +} + +// EntitiesIDs returns the "entities" edge IDs in the mutation. +func (m *FileMutation) EntitiesIDs() (ids []int) { + for id := range m.entities { + ids = append(ids, id) + } + return +} + +// ResetEntities resets all changes to the "entities" edge. +func (m *FileMutation) ResetEntities() { + m.entities = nil + m.clearedentities = false + m.removedentities = nil +} + +// AddShareIDs adds the "shares" edge to the Share entity by ids. +func (m *FileMutation) AddShareIDs(ids ...int) { + if m.shares == nil { + m.shares = make(map[int]struct{}) + } + for i := range ids { + m.shares[ids[i]] = struct{}{} + } +} + +// ClearShares clears the "shares" edge to the Share entity. +func (m *FileMutation) ClearShares() { + m.clearedshares = true +} + +// SharesCleared reports if the "shares" edge to the Share entity was cleared. +func (m *FileMutation) SharesCleared() bool { + return m.clearedshares +} + +// RemoveShareIDs removes the "shares" edge to the Share entity by IDs. +func (m *FileMutation) RemoveShareIDs(ids ...int) { + if m.removedshares == nil { + m.removedshares = make(map[int]struct{}) + } + for i := range ids { + delete(m.shares, ids[i]) + m.removedshares[ids[i]] = struct{}{} + } +} + +// RemovedShares returns the removed IDs of the "shares" edge to the Share entity. 
+func (m *FileMutation) RemovedSharesIDs() (ids []int) { + for id := range m.removedshares { + ids = append(ids, id) + } + return +} + +// SharesIDs returns the "shares" edge IDs in the mutation. +func (m *FileMutation) SharesIDs() (ids []int) { + for id := range m.shares { + ids = append(ids, id) + } + return +} + +// ResetShares resets all changes to the "shares" edge. +func (m *FileMutation) ResetShares() { + m.shares = nil + m.clearedshares = false + m.removedshares = nil +} + +// AddDirectLinkIDs adds the "direct_links" edge to the DirectLink entity by ids. +func (m *FileMutation) AddDirectLinkIDs(ids ...int) { + if m.direct_links == nil { + m.direct_links = make(map[int]struct{}) + } + for i := range ids { + m.direct_links[ids[i]] = struct{}{} + } +} + +// ClearDirectLinks clears the "direct_links" edge to the DirectLink entity. +func (m *FileMutation) ClearDirectLinks() { + m.cleareddirect_links = true +} + +// DirectLinksCleared reports if the "direct_links" edge to the DirectLink entity was cleared. +func (m *FileMutation) DirectLinksCleared() bool { + return m.cleareddirect_links +} + +// RemoveDirectLinkIDs removes the "direct_links" edge to the DirectLink entity by IDs. +func (m *FileMutation) RemoveDirectLinkIDs(ids ...int) { + if m.removeddirect_links == nil { + m.removeddirect_links = make(map[int]struct{}) + } + for i := range ids { + delete(m.direct_links, ids[i]) + m.removeddirect_links[ids[i]] = struct{}{} + } +} + +// RemovedDirectLinks returns the removed IDs of the "direct_links" edge to the DirectLink entity. +func (m *FileMutation) RemovedDirectLinksIDs() (ids []int) { + for id := range m.removeddirect_links { + ids = append(ids, id) + } + return +} + +// DirectLinksIDs returns the "direct_links" edge IDs in the mutation. +func (m *FileMutation) DirectLinksIDs() (ids []int) { + for id := range m.direct_links { + ids = append(ids, id) + } + return +} + +// ResetDirectLinks resets all changes to the "direct_links" edge. +func (m *FileMutation) ResetDirectLinks() { + m.direct_links = nil + m.cleareddirect_links = false + m.removeddirect_links = nil +} + +// Where appends a list predicates to the FileMutation builder. +func (m *FileMutation) Where(ps ...predicate.File) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the FileMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *FileMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.File, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *FileMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *FileMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (File). +func (m *FileMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
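+// A typical consumer is a generic audit hook that enumerates the staged
+// values; a rough hand-written sketch (the hook name and log call are
+// illustrative, not part of this package):
+//
+//	func auditFiles(next ent.Mutator) ent.Mutator {
+//		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+//			for _, name := range m.Fields() {
+//				if v, ok := m.Field(name); ok {
+//					log.Printf("%s sets %s=%v", m.Type(), name, v)
+//				}
+//			}
+//			return next.Mutate(ctx, m)
+//		})
+//	}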
+func (m *FileMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.created_at != nil { + fields = append(fields, file.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, file.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, file.FieldDeletedAt) + } + if m._type != nil { + fields = append(fields, file.FieldType) + } + if m.name != nil { + fields = append(fields, file.FieldName) + } + if m.owner != nil { + fields = append(fields, file.FieldOwnerID) + } + if m.size != nil { + fields = append(fields, file.FieldSize) + } + if m.primary_entity != nil { + fields = append(fields, file.FieldPrimaryEntity) + } + if m.parent != nil { + fields = append(fields, file.FieldFileChildren) + } + if m.is_symbolic != nil { + fields = append(fields, file.FieldIsSymbolic) + } + if m.props != nil { + fields = append(fields, file.FieldProps) + } + if m.storage_policies != nil { + fields = append(fields, file.FieldStoragePolicyFiles) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *FileMutation) Field(name string) (ent.Value, bool) { + switch name { + case file.FieldCreatedAt: + return m.CreatedAt() + case file.FieldUpdatedAt: + return m.UpdatedAt() + case file.FieldDeletedAt: + return m.DeletedAt() + case file.FieldType: + return m.GetType() + case file.FieldName: + return m.Name() + case file.FieldOwnerID: + return m.OwnerID() + case file.FieldSize: + return m.Size() + case file.FieldPrimaryEntity: + return m.PrimaryEntity() + case file.FieldFileChildren: + return m.FileChildren() + case file.FieldIsSymbolic: + return m.IsSymbolic() + case file.FieldProps: + return m.Props() + case file.FieldStoragePolicyFiles: + return m.StoragePolicyFiles() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *FileMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case file.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case file.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case file.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case file.FieldType: + return m.OldType(ctx) + case file.FieldName: + return m.OldName(ctx) + case file.FieldOwnerID: + return m.OldOwnerID(ctx) + case file.FieldSize: + return m.OldSize(ctx) + case file.FieldPrimaryEntity: + return m.OldPrimaryEntity(ctx) + case file.FieldFileChildren: + return m.OldFileChildren(ctx) + case file.FieldIsSymbolic: + return m.OldIsSymbolic(ctx) + case file.FieldProps: + return m.OldProps(ctx) + case file.FieldStoragePolicyFiles: + return m.OldStoragePolicyFiles(ctx) + } + return nil, fmt.Errorf("unknown File field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *FileMutation) SetField(name string, value ent.Value) error { + switch name { + case file.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case file.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case file.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case file.FieldType: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case file.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case file.FieldOwnerID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOwnerID(v) + return nil + case file.FieldSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSize(v) + return nil + case file.FieldPrimaryEntity: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrimaryEntity(v) + return nil + case file.FieldFileChildren: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFileChildren(v) + return nil + case file.FieldIsSymbolic: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsSymbolic(v) + return nil + case file.FieldProps: + v, ok := value.(*types.FileProps) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProps(v) + return nil + case file.FieldStoragePolicyFiles: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStoragePolicyFiles(v) + return nil + } + return fmt.Errorf("unknown File field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *FileMutation) AddedFields() []string { + var fields []string + if m.add_type != nil { + fields = append(fields, file.FieldType) + } + if m.addsize != nil { + fields = append(fields, file.FieldSize) + } + if m.addprimary_entity != nil { + fields = append(fields, file.FieldPrimaryEntity) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *FileMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case file.FieldType: + return m.AddedType() + case file.FieldSize: + return m.AddedSize() + case file.FieldPrimaryEntity: + return m.AddedPrimaryEntity() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *FileMutation) AddField(name string, value ent.Value) error { + switch name { + case file.FieldType: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddType(v) + return nil + case file.FieldSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSize(v) + return nil + case file.FieldPrimaryEntity: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPrimaryEntity(v) + return nil + } + return fmt.Errorf("unknown File numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *FileMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(file.FieldDeletedAt) { + fields = append(fields, file.FieldDeletedAt) + } + if m.FieldCleared(file.FieldPrimaryEntity) { + fields = append(fields, file.FieldPrimaryEntity) + } + if m.FieldCleared(file.FieldFileChildren) { + fields = append(fields, file.FieldFileChildren) + } + if m.FieldCleared(file.FieldProps) { + fields = append(fields, file.FieldProps) + } + if m.FieldCleared(file.FieldStoragePolicyFiles) { + fields = append(fields, file.FieldStoragePolicyFiles) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *FileMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *FileMutation) ClearField(name string) error { + switch name { + case file.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case file.FieldPrimaryEntity: + m.ClearPrimaryEntity() + return nil + case file.FieldFileChildren: + m.ClearFileChildren() + return nil + case file.FieldProps: + m.ClearProps() + return nil + case file.FieldStoragePolicyFiles: + m.ClearStoragePolicyFiles() + return nil + } + return fmt.Errorf("unknown File nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *FileMutation) ResetField(name string) error { + switch name { + case file.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case file.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case file.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case file.FieldType: + m.ResetType() + return nil + case file.FieldName: + m.ResetName() + return nil + case file.FieldOwnerID: + m.ResetOwnerID() + return nil + case file.FieldSize: + m.ResetSize() + return nil + case file.FieldPrimaryEntity: + m.ResetPrimaryEntity() + return nil + case file.FieldFileChildren: + m.ResetFileChildren() + return nil + case file.FieldIsSymbolic: + m.ResetIsSymbolic() + return nil + case file.FieldProps: + m.ResetProps() + return nil + case file.FieldStoragePolicyFiles: + m.ResetStoragePolicyFiles() + return nil + } + return fmt.Errorf("unknown File field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. 
+func (m *FileMutation) AddedEdges() []string { + edges := make([]string, 0, 8) + if m.owner != nil { + edges = append(edges, file.EdgeOwner) + } + if m.storage_policies != nil { + edges = append(edges, file.EdgeStoragePolicies) + } + if m.parent != nil { + edges = append(edges, file.EdgeParent) + } + if m.children != nil { + edges = append(edges, file.EdgeChildren) + } + if m.metadata != nil { + edges = append(edges, file.EdgeMetadata) + } + if m.entities != nil { + edges = append(edges, file.EdgeEntities) + } + if m.shares != nil { + edges = append(edges, file.EdgeShares) + } + if m.direct_links != nil { + edges = append(edges, file.EdgeDirectLinks) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *FileMutation) AddedIDs(name string) []ent.Value { + switch name { + case file.EdgeOwner: + if id := m.owner; id != nil { + return []ent.Value{*id} + } + case file.EdgeStoragePolicies: + if id := m.storage_policies; id != nil { + return []ent.Value{*id} + } + case file.EdgeParent: + if id := m.parent; id != nil { + return []ent.Value{*id} + } + case file.EdgeChildren: + ids := make([]ent.Value, 0, len(m.children)) + for id := range m.children { + ids = append(ids, id) + } + return ids + case file.EdgeMetadata: + ids := make([]ent.Value, 0, len(m.metadata)) + for id := range m.metadata { + ids = append(ids, id) + } + return ids + case file.EdgeEntities: + ids := make([]ent.Value, 0, len(m.entities)) + for id := range m.entities { + ids = append(ids, id) + } + return ids + case file.EdgeShares: + ids := make([]ent.Value, 0, len(m.shares)) + for id := range m.shares { + ids = append(ids, id) + } + return ids + case file.EdgeDirectLinks: + ids := make([]ent.Value, 0, len(m.direct_links)) + for id := range m.direct_links { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *FileMutation) RemovedEdges() []string { + edges := make([]string, 0, 8) + if m.removedchildren != nil { + edges = append(edges, file.EdgeChildren) + } + if m.removedmetadata != nil { + edges = append(edges, file.EdgeMetadata) + } + if m.removedentities != nil { + edges = append(edges, file.EdgeEntities) + } + if m.removedshares != nil { + edges = append(edges, file.EdgeShares) + } + if m.removeddirect_links != nil { + edges = append(edges, file.EdgeDirectLinks) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. 
+func (m *FileMutation) RemovedIDs(name string) []ent.Value { + switch name { + case file.EdgeChildren: + ids := make([]ent.Value, 0, len(m.removedchildren)) + for id := range m.removedchildren { + ids = append(ids, id) + } + return ids + case file.EdgeMetadata: + ids := make([]ent.Value, 0, len(m.removedmetadata)) + for id := range m.removedmetadata { + ids = append(ids, id) + } + return ids + case file.EdgeEntities: + ids := make([]ent.Value, 0, len(m.removedentities)) + for id := range m.removedentities { + ids = append(ids, id) + } + return ids + case file.EdgeShares: + ids := make([]ent.Value, 0, len(m.removedshares)) + for id := range m.removedshares { + ids = append(ids, id) + } + return ids + case file.EdgeDirectLinks: + ids := make([]ent.Value, 0, len(m.removeddirect_links)) + for id := range m.removeddirect_links { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *FileMutation) ClearedEdges() []string { + edges := make([]string, 0, 8) + if m.clearedowner { + edges = append(edges, file.EdgeOwner) + } + if m.clearedstorage_policies { + edges = append(edges, file.EdgeStoragePolicies) + } + if m.clearedparent { + edges = append(edges, file.EdgeParent) + } + if m.clearedchildren { + edges = append(edges, file.EdgeChildren) + } + if m.clearedmetadata { + edges = append(edges, file.EdgeMetadata) + } + if m.clearedentities { + edges = append(edges, file.EdgeEntities) + } + if m.clearedshares { + edges = append(edges, file.EdgeShares) + } + if m.cleareddirect_links { + edges = append(edges, file.EdgeDirectLinks) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *FileMutation) EdgeCleared(name string) bool { + switch name { + case file.EdgeOwner: + return m.clearedowner + case file.EdgeStoragePolicies: + return m.clearedstorage_policies + case file.EdgeParent: + return m.clearedparent + case file.EdgeChildren: + return m.clearedchildren + case file.EdgeMetadata: + return m.clearedmetadata + case file.EdgeEntities: + return m.clearedentities + case file.EdgeShares: + return m.clearedshares + case file.EdgeDirectLinks: + return m.cleareddirect_links + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *FileMutation) ClearEdge(name string) error { + switch name { + case file.EdgeOwner: + m.ClearOwner() + return nil + case file.EdgeStoragePolicies: + m.ClearStoragePolicies() + return nil + case file.EdgeParent: + m.ClearParent() + return nil + } + return fmt.Errorf("unknown File unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. 
+func (m *FileMutation) ResetEdge(name string) error { + switch name { + case file.EdgeOwner: + m.ResetOwner() + return nil + case file.EdgeStoragePolicies: + m.ResetStoragePolicies() + return nil + case file.EdgeParent: + m.ResetParent() + return nil + case file.EdgeChildren: + m.ResetChildren() + return nil + case file.EdgeMetadata: + m.ResetMetadata() + return nil + case file.EdgeEntities: + m.ResetEntities() + return nil + case file.EdgeShares: + m.ResetShares() + return nil + case file.EdgeDirectLinks: + m.ResetDirectLinks() + return nil + } + return fmt.Errorf("unknown File edge %s", name) +} + +// GroupMutation represents an operation that mutates the Group nodes in the graph. +type GroupMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + max_storage *int64 + addmax_storage *int64 + speed_limit *int + addspeed_limit *int + permissions **boolset.BooleanSet + settings **types.GroupSetting + clearedFields map[string]struct{} + users map[int]struct{} + removedusers map[int]struct{} + clearedusers bool + storage_policies *int + clearedstorage_policies bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group +} + +var _ ent.Mutation = (*GroupMutation)(nil) + +// groupOption allows management of the mutation configuration using functional options. +type groupOption func(*GroupMutation) + +// newGroupMutation creates new mutation for the Group entity. +func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { + m := &GroupMutation{ + config: c, + op: op, + typ: TypeGroup, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withGroupID sets the ID field of the mutation. +func withGroupID(id int) groupOption { + return func(m *GroupMutation) { + var ( + err error + once sync.Once + value *Group + ) + m.oldValue = func(ctx context.Context) (*Group, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Group.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withGroup sets the old Group of the mutation. +func withGroup(node *Group) groupOption { + return func(m *GroupMutation) { + m.oldValue = func(context.Context) (*Group, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m GroupMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m GroupMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *GroupMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. 
+// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *GroupMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *GroupMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *GroupMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *GroupMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *GroupMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *GroupMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. 
+// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *GroupMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[group.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *GroupMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[group.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *GroupMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, group.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *GroupMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *GroupMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *GroupMutation) ResetName() { + m.name = nil +} + +// SetMaxStorage sets the "max_storage" field. +func (m *GroupMutation) SetMaxStorage(i int64) { + m.max_storage = &i + m.addmax_storage = nil +} + +// MaxStorage returns the value of the "max_storage" field in the mutation. +func (m *GroupMutation) MaxStorage() (r int64, exists bool) { + v := m.max_storage + if v == nil { + return + } + return *v, true +} + +// OldMaxStorage returns the old "max_storage" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMaxStorage(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMaxStorage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMaxStorage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMaxStorage: %w", err) + } + return oldValue.MaxStorage, nil +} + +// AddMaxStorage adds i to the "max_storage" field. 
+func (m *GroupMutation) AddMaxStorage(i int64) { + if m.addmax_storage != nil { + *m.addmax_storage += i + } else { + m.addmax_storage = &i + } +} + +// AddedMaxStorage returns the value that was added to the "max_storage" field in this mutation. +func (m *GroupMutation) AddedMaxStorage() (r int64, exists bool) { + v := m.addmax_storage + if v == nil { + return + } + return *v, true +} + +// ClearMaxStorage clears the value of the "max_storage" field. +func (m *GroupMutation) ClearMaxStorage() { + m.max_storage = nil + m.addmax_storage = nil + m.clearedFields[group.FieldMaxStorage] = struct{}{} +} + +// MaxStorageCleared returns if the "max_storage" field was cleared in this mutation. +func (m *GroupMutation) MaxStorageCleared() bool { + _, ok := m.clearedFields[group.FieldMaxStorage] + return ok +} + +// ResetMaxStorage resets all changes to the "max_storage" field. +func (m *GroupMutation) ResetMaxStorage() { + m.max_storage = nil + m.addmax_storage = nil + delete(m.clearedFields, group.FieldMaxStorage) +} + +// SetSpeedLimit sets the "speed_limit" field. +func (m *GroupMutation) SetSpeedLimit(i int) { + m.speed_limit = &i + m.addspeed_limit = nil +} + +// SpeedLimit returns the value of the "speed_limit" field in the mutation. +func (m *GroupMutation) SpeedLimit() (r int, exists bool) { + v := m.speed_limit + if v == nil { + return + } + return *v, true +} + +// OldSpeedLimit returns the old "speed_limit" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSpeedLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSpeedLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSpeedLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSpeedLimit: %w", err) + } + return oldValue.SpeedLimit, nil +} + +// AddSpeedLimit adds i to the "speed_limit" field. +func (m *GroupMutation) AddSpeedLimit(i int) { + if m.addspeed_limit != nil { + *m.addspeed_limit += i + } else { + m.addspeed_limit = &i + } +} + +// AddedSpeedLimit returns the value that was added to the "speed_limit" field in this mutation. +func (m *GroupMutation) AddedSpeedLimit() (r int, exists bool) { + v := m.addspeed_limit + if v == nil { + return + } + return *v, true +} + +// ClearSpeedLimit clears the value of the "speed_limit" field. +func (m *GroupMutation) ClearSpeedLimit() { + m.speed_limit = nil + m.addspeed_limit = nil + m.clearedFields[group.FieldSpeedLimit] = struct{}{} +} + +// SpeedLimitCleared returns if the "speed_limit" field was cleared in this mutation. +func (m *GroupMutation) SpeedLimitCleared() bool { + _, ok := m.clearedFields[group.FieldSpeedLimit] + return ok +} + +// ResetSpeedLimit resets all changes to the "speed_limit" field. +func (m *GroupMutation) ResetSpeedLimit() { + m.speed_limit = nil + m.addspeed_limit = nil + delete(m.clearedFields, group.FieldSpeedLimit) +} + +// SetPermissions sets the "permissions" field. +func (m *GroupMutation) SetPermissions(bs *boolset.BooleanSet) { + m.permissions = &bs +} + +// Permissions returns the value of the "permissions" field in the mutation. 
+func (m *GroupMutation) Permissions() (r *boolset.BooleanSet, exists bool) { + v := m.permissions + if v == nil { + return + } + return *v, true +} + +// OldPermissions returns the old "permissions" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldPermissions(ctx context.Context) (v *boolset.BooleanSet, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPermissions is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPermissions requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPermissions: %w", err) + } + return oldValue.Permissions, nil +} + +// ResetPermissions resets all changes to the "permissions" field. +func (m *GroupMutation) ResetPermissions() { + m.permissions = nil +} + +// SetSettings sets the "settings" field. +func (m *GroupMutation) SetSettings(ts *types.GroupSetting) { + m.settings = &ts +} + +// Settings returns the value of the "settings" field in the mutation. +func (m *GroupMutation) Settings() (r *types.GroupSetting, exists bool) { + v := m.settings + if v == nil { + return + } + return *v, true +} + +// OldSettings returns the old "settings" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSettings(ctx context.Context) (v *types.GroupSetting, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSettings is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSettings requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSettings: %w", err) + } + return oldValue.Settings, nil +} + +// ClearSettings clears the value of the "settings" field. +func (m *GroupMutation) ClearSettings() { + m.settings = nil + m.clearedFields[group.FieldSettings] = struct{}{} +} + +// SettingsCleared returns if the "settings" field was cleared in this mutation. +func (m *GroupMutation) SettingsCleared() bool { + _, ok := m.clearedFields[group.FieldSettings] + return ok +} + +// ResetSettings resets all changes to the "settings" field. +func (m *GroupMutation) ResetSettings() { + m.settings = nil + delete(m.clearedFields, group.FieldSettings) +} + +// SetStoragePolicyID sets the "storage_policy_id" field. +func (m *GroupMutation) SetStoragePolicyID(i int) { + m.storage_policies = &i +} + +// StoragePolicyID returns the value of the "storage_policy_id" field in the mutation. +func (m *GroupMutation) StoragePolicyID() (r int, exists bool) { + v := m.storage_policies + if v == nil { + return + } + return *v, true +} + +// OldStoragePolicyID returns the old "storage_policy_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldStoragePolicyID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStoragePolicyID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStoragePolicyID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStoragePolicyID: %w", err) + } + return oldValue.StoragePolicyID, nil +} + +// ClearStoragePolicyID clears the value of the "storage_policy_id" field. +func (m *GroupMutation) ClearStoragePolicyID() { + m.storage_policies = nil + m.clearedFields[group.FieldStoragePolicyID] = struct{}{} +} + +// StoragePolicyIDCleared returns if the "storage_policy_id" field was cleared in this mutation. +func (m *GroupMutation) StoragePolicyIDCleared() bool { + _, ok := m.clearedFields[group.FieldStoragePolicyID] + return ok +} + +// ResetStoragePolicyID resets all changes to the "storage_policy_id" field. +func (m *GroupMutation) ResetStoragePolicyID() { + m.storage_policies = nil + delete(m.clearedFields, group.FieldStoragePolicyID) +} + +// AddUserIDs adds the "users" edge to the User entity by ids. +func (m *GroupMutation) AddUserIDs(ids ...int) { + if m.users == nil { + m.users = make(map[int]struct{}) + } + for i := range ids { + m.users[ids[i]] = struct{}{} + } +} + +// ClearUsers clears the "users" edge to the User entity. +func (m *GroupMutation) ClearUsers() { + m.clearedusers = true +} + +// UsersCleared reports if the "users" edge to the User entity was cleared. +func (m *GroupMutation) UsersCleared() bool { + return m.clearedusers +} + +// RemoveUserIDs removes the "users" edge to the User entity by IDs. +func (m *GroupMutation) RemoveUserIDs(ids ...int) { + if m.removedusers == nil { + m.removedusers = make(map[int]struct{}) + } + for i := range ids { + delete(m.users, ids[i]) + m.removedusers[ids[i]] = struct{}{} + } +} + +// RemovedUsers returns the removed IDs of the "users" edge to the User entity. +func (m *GroupMutation) RemovedUsersIDs() (ids []int) { + for id := range m.removedusers { + ids = append(ids, id) + } + return +} + +// UsersIDs returns the "users" edge IDs in the mutation. +func (m *GroupMutation) UsersIDs() (ids []int) { + for id := range m.users { + ids = append(ids, id) + } + return +} + +// ResetUsers resets all changes to the "users" edge. +func (m *GroupMutation) ResetUsers() { + m.users = nil + m.clearedusers = false + m.removedusers = nil +} + +// SetStoragePoliciesID sets the "storage_policies" edge to the StoragePolicy entity by id. +func (m *GroupMutation) SetStoragePoliciesID(id int) { + m.storage_policies = &id +} + +// ClearStoragePolicies clears the "storage_policies" edge to the StoragePolicy entity. +func (m *GroupMutation) ClearStoragePolicies() { + m.clearedstorage_policies = true + m.clearedFields[group.FieldStoragePolicyID] = struct{}{} +} + +// StoragePoliciesCleared reports if the "storage_policies" edge to the StoragePolicy entity was cleared. +func (m *GroupMutation) StoragePoliciesCleared() bool { + return m.StoragePolicyIDCleared() || m.clearedstorage_policies +} + +// StoragePoliciesID returns the "storage_policies" edge ID in the mutation. +func (m *GroupMutation) StoragePoliciesID() (id int, exists bool) { + if m.storage_policies != nil { + return *m.storage_policies, true + } + return +} + +// StoragePoliciesIDs returns the "storage_policies" edge IDs in the mutation. 
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// StoragePoliciesID instead. It exists only for internal usage by the builders. +func (m *GroupMutation) StoragePoliciesIDs() (ids []int) { + if id := m.storage_policies; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetStoragePolicies resets all changes to the "storage_policies" edge. +func (m *GroupMutation) ResetStoragePolicies() { + m.storage_policies = nil + m.clearedstorage_policies = false +} + +// Where appends a list predicates to the GroupMutation builder. +func (m *GroupMutation) Where(ps ...predicate.Group) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Group, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *GroupMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *GroupMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Group). +func (m *GroupMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *GroupMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, group.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, group.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, group.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, group.FieldName) + } + if m.max_storage != nil { + fields = append(fields, group.FieldMaxStorage) + } + if m.speed_limit != nil { + fields = append(fields, group.FieldSpeedLimit) + } + if m.permissions != nil { + fields = append(fields, group.FieldPermissions) + } + if m.settings != nil { + fields = append(fields, group.FieldSettings) + } + if m.storage_policies != nil { + fields = append(fields, group.FieldStoragePolicyID) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *GroupMutation) Field(name string) (ent.Value, bool) { + switch name { + case group.FieldCreatedAt: + return m.CreatedAt() + case group.FieldUpdatedAt: + return m.UpdatedAt() + case group.FieldDeletedAt: + return m.DeletedAt() + case group.FieldName: + return m.Name() + case group.FieldMaxStorage: + return m.MaxStorage() + case group.FieldSpeedLimit: + return m.SpeedLimit() + case group.FieldPermissions: + return m.Permissions() + case group.FieldSettings: + return m.Settings() + case group.FieldStoragePolicyID: + return m.StoragePolicyID() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case group.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case group.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case group.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case group.FieldName: + return m.OldName(ctx) + case group.FieldMaxStorage: + return m.OldMaxStorage(ctx) + case group.FieldSpeedLimit: + return m.OldSpeedLimit(ctx) + case group.FieldPermissions: + return m.OldPermissions(ctx) + case group.FieldSettings: + return m.OldSettings(ctx) + case group.FieldStoragePolicyID: + return m.OldStoragePolicyID(ctx) + } + return nil, fmt.Errorf("unknown Group field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *GroupMutation) SetField(name string, value ent.Value) error { + switch name { + case group.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case group.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case group.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case group.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case group.FieldMaxStorage: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMaxStorage(v) + return nil + case group.FieldSpeedLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSpeedLimit(v) + return nil + case group.FieldPermissions: + v, ok := value.(*boolset.BooleanSet) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPermissions(v) + return nil + case group.FieldSettings: + v, ok := value.(*types.GroupSetting) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSettings(v) + return nil + case group.FieldStoragePolicyID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStoragePolicyID(v) + return nil + } + return fmt.Errorf("unknown Group field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *GroupMutation) AddedFields() []string { + var fields []string + if m.addmax_storage != nil { + fields = append(fields, group.FieldMaxStorage) + } + if m.addspeed_limit != nil { + fields = append(fields, group.FieldSpeedLimit) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case group.FieldMaxStorage: + return m.AddedMaxStorage() + case group.FieldSpeedLimit: + return m.AddedSpeedLimit() + } + return nil, false +} + +// AddField adds the value to the field with the given name. 
It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *GroupMutation) AddField(name string, value ent.Value) error { + switch name { + case group.FieldMaxStorage: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMaxStorage(v) + return nil + case group.FieldSpeedLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSpeedLimit(v) + return nil + } + return fmt.Errorf("unknown Group numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *GroupMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(group.FieldDeletedAt) { + fields = append(fields, group.FieldDeletedAt) + } + if m.FieldCleared(group.FieldMaxStorage) { + fields = append(fields, group.FieldMaxStorage) + } + if m.FieldCleared(group.FieldSpeedLimit) { + fields = append(fields, group.FieldSpeedLimit) + } + if m.FieldCleared(group.FieldSettings) { + fields = append(fields, group.FieldSettings) + } + if m.FieldCleared(group.FieldStoragePolicyID) { + fields = append(fields, group.FieldStoragePolicyID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *GroupMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *GroupMutation) ClearField(name string) error { + switch name { + case group.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case group.FieldMaxStorage: + m.ClearMaxStorage() + return nil + case group.FieldSpeedLimit: + m.ClearSpeedLimit() + return nil + case group.FieldSettings: + m.ClearSettings() + return nil + case group.FieldStoragePolicyID: + m.ClearStoragePolicyID() + return nil + } + return fmt.Errorf("unknown Group nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *GroupMutation) ResetField(name string) error { + switch name { + case group.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case group.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case group.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case group.FieldName: + m.ResetName() + return nil + case group.FieldMaxStorage: + m.ResetMaxStorage() + return nil + case group.FieldSpeedLimit: + m.ResetSpeedLimit() + return nil + case group.FieldPermissions: + m.ResetPermissions() + return nil + case group.FieldSettings: + m.ResetSettings() + return nil + case group.FieldStoragePolicyID: + m.ResetStoragePolicyID() + return nil + } + return fmt.Errorf("unknown Group field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *GroupMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.users != nil { + edges = append(edges, group.EdgeUsers) + } + if m.storage_policies != nil { + edges = append(edges, group.EdgeStoragePolicies) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeUsers: + ids := make([]ent.Value, 0, len(m.users)) + for id := range m.users { + ids = append(ids, id) + } + return ids + case group.EdgeStoragePolicies: + if id := m.storage_policies; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *GroupMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + if m.removedusers != nil { + edges = append(edges, group.EdgeUsers) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeUsers: + ids := make([]ent.Value, 0, len(m.removedusers)) + for id := range m.removedusers { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *GroupMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedusers { + edges = append(edges, group.EdgeUsers) + } + if m.clearedstorage_policies { + edges = append(edges, group.EdgeStoragePolicies) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeUsers: + return m.clearedusers + case group.EdgeStoragePolicies: + return m.clearedstorage_policies + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *GroupMutation) ClearEdge(name string) error { + switch name { + case group.EdgeStoragePolicies: + m.ClearStoragePolicies() + return nil + } + return fmt.Errorf("unknown Group unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeUsers: + m.ResetUsers() + return nil + case group.EdgeStoragePolicies: + m.ResetStoragePolicies() + return nil + } + return fmt.Errorf("unknown Group edge %s", name) +} + +// MetadataMutation represents an operation that mutates the Metadata nodes in the graph. +type MetadataMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + value *string + is_public *bool + clearedFields map[string]struct{} + file *int + clearedfile bool + done bool + oldValue func(context.Context) (*Metadata, error) + predicates []predicate.Metadata +} + +var _ ent.Mutation = (*MetadataMutation)(nil) + +// metadataOption allows management of the mutation configuration using functional options. +type metadataOption func(*MetadataMutation) + +// newMetadataMutation creates new mutation for the Metadata entity. +func newMetadataMutation(c config, op Op, opts ...metadataOption) *MetadataMutation { + m := &MetadataMutation{ + config: c, + op: op, + typ: TypeMetadata, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withMetadataID sets the ID field of the mutation. 
+func withMetadataID(id int) metadataOption { + return func(m *MetadataMutation) { + var ( + err error + once sync.Once + value *Metadata + ) + m.oldValue = func(ctx context.Context) (*Metadata, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Metadata.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withMetadata sets the old Metadata of the mutation. +func withMetadata(node *Metadata) metadataOption { + return func(m *MetadataMutation) { + m.oldValue = func(context.Context) (*Metadata, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MetadataMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MetadataMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *MetadataMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *MetadataMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Metadata.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *MetadataMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *MetadataMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. 
+func (m *MetadataMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *MetadataMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *MetadataMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *MetadataMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *MetadataMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *MetadataMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *MetadataMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[metadata.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *MetadataMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[metadata.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *MetadataMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, metadata.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *MetadataMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *MetadataMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *MetadataMutation) ResetName() { + m.name = nil +} + +// SetValue sets the "value" field. +func (m *MetadataMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. +func (m *MetadataMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ResetValue resets all changes to the "value" field. +func (m *MetadataMutation) ResetValue() { + m.value = nil +} + +// SetFileID sets the "file_id" field. +func (m *MetadataMutation) SetFileID(i int) { + m.file = &i +} + +// FileID returns the value of the "file_id" field in the mutation. +func (m *MetadataMutation) FileID() (r int, exists bool) { + v := m.file + if v == nil { + return + } + return *v, true +} + +// OldFileID returns the old "file_id" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldFileID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFileID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFileID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFileID: %w", err) + } + return oldValue.FileID, nil +} + +// ResetFileID resets all changes to the "file_id" field. +func (m *MetadataMutation) ResetFileID() { + m.file = nil +} + +// SetIsPublic sets the "is_public" field. +func (m *MetadataMutation) SetIsPublic(b bool) { + m.is_public = &b +} + +// IsPublic returns the value of the "is_public" field in the mutation. +func (m *MetadataMutation) IsPublic() (r bool, exists bool) { + v := m.is_public + if v == nil { + return + } + return *v, true +} + +// OldIsPublic returns the old "is_public" field's value of the Metadata entity. +// If the Metadata object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MetadataMutation) OldIsPublic(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsPublic is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsPublic requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsPublic: %w", err) + } + return oldValue.IsPublic, nil +} + +// ResetIsPublic resets all changes to the "is_public" field. +func (m *MetadataMutation) ResetIsPublic() { + m.is_public = nil +} + +// ClearFile clears the "file" edge to the File entity. +func (m *MetadataMutation) ClearFile() { + m.clearedfile = true + m.clearedFields[metadata.FieldFileID] = struct{}{} +} + +// FileCleared reports if the "file" edge to the File entity was cleared. +func (m *MetadataMutation) FileCleared() bool { + return m.clearedfile +} + +// FileIDs returns the "file" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// FileID instead. It exists only for internal usage by the builders. +func (m *MetadataMutation) FileIDs() (ids []int) { + if id := m.file; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetFile resets all changes to the "file" edge. +func (m *MetadataMutation) ResetFile() { + m.file = nil + m.clearedfile = false +} + +// Where appends a list predicates to the MetadataMutation builder. +func (m *MetadataMutation) Where(ps ...predicate.Metadata) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the MetadataMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MetadataMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Metadata, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *MetadataMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *MetadataMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Metadata). +func (m *MetadataMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *MetadataMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, metadata.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, metadata.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, metadata.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, metadata.FieldName) + } + if m.value != nil { + fields = append(fields, metadata.FieldValue) + } + if m.file != nil { + fields = append(fields, metadata.FieldFileID) + } + if m.is_public != nil { + fields = append(fields, metadata.FieldIsPublic) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *MetadataMutation) Field(name string) (ent.Value, bool) { + switch name { + case metadata.FieldCreatedAt: + return m.CreatedAt() + case metadata.FieldUpdatedAt: + return m.UpdatedAt() + case metadata.FieldDeletedAt: + return m.DeletedAt() + case metadata.FieldName: + return m.Name() + case metadata.FieldValue: + return m.Value() + case metadata.FieldFileID: + return m.FileID() + case metadata.FieldIsPublic: + return m.IsPublic() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *MetadataMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case metadata.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case metadata.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case metadata.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case metadata.FieldName: + return m.OldName(ctx) + case metadata.FieldValue: + return m.OldValue(ctx) + case metadata.FieldFileID: + return m.OldFileID(ctx) + case metadata.FieldIsPublic: + return m.OldIsPublic(ctx) + } + return nil, fmt.Errorf("unknown Metadata field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MetadataMutation) SetField(name string, value ent.Value) error { + switch name { + case metadata.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case metadata.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case metadata.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case metadata.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case metadata.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + case metadata.FieldFileID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFileID(v) + return nil + case metadata.FieldIsPublic: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsPublic(v) + return nil + } + return fmt.Errorf("unknown Metadata field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *MetadataMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *MetadataMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *MetadataMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Metadata numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *MetadataMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(metadata.FieldDeletedAt) { + fields = append(fields, metadata.FieldDeletedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *MetadataMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *MetadataMutation) ClearField(name string) error { + switch name { + case metadata.FieldDeletedAt: + m.ClearDeletedAt() + return nil + } + return fmt.Errorf("unknown Metadata nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *MetadataMutation) ResetField(name string) error { + switch name { + case metadata.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case metadata.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case metadata.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case metadata.FieldName: + m.ResetName() + return nil + case metadata.FieldValue: + m.ResetValue() + return nil + case metadata.FieldFileID: + m.ResetFileID() + return nil + case metadata.FieldIsPublic: + m.ResetIsPublic() + return nil + } + return fmt.Errorf("unknown Metadata field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *MetadataMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.file != nil { + edges = append(edges, metadata.EdgeFile) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *MetadataMutation) AddedIDs(name string) []ent.Value { + switch name { + case metadata.EdgeFile: + if id := m.file; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *MetadataMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *MetadataMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *MetadataMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedfile { + edges = append(edges, metadata.EdgeFile) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *MetadataMutation) EdgeCleared(name string) bool { + switch name { + case metadata.EdgeFile: + return m.clearedfile + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. 
+func (m *MetadataMutation) ClearEdge(name string) error { + switch name { + case metadata.EdgeFile: + m.ClearFile() + return nil + } + return fmt.Errorf("unknown Metadata unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *MetadataMutation) ResetEdge(name string) error { + switch name { + case metadata.EdgeFile: + m.ResetFile() + return nil + } + return fmt.Errorf("unknown Metadata edge %s", name) +} + +// NodeMutation represents an operation that mutates the Node nodes in the graph. +type NodeMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + status *node.Status + name *string + _type *node.Type + server *string + slave_key *string + capabilities **boolset.BooleanSet + settings **types.NodeSetting + weight *int + addweight *int + clearedFields map[string]struct{} + storage_policy map[int]struct{} + removedstorage_policy map[int]struct{} + clearedstorage_policy bool + done bool + oldValue func(context.Context) (*Node, error) + predicates []predicate.Node +} + +var _ ent.Mutation = (*NodeMutation)(nil) + +// nodeOption allows management of the mutation configuration using functional options. +type nodeOption func(*NodeMutation) + +// newNodeMutation creates new mutation for the Node entity. +func newNodeMutation(c config, op Op, opts ...nodeOption) *NodeMutation { + m := &NodeMutation{ + config: c, + op: op, + typ: TypeNode, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withNodeID sets the ID field of the mutation. +func withNodeID(id int) nodeOption { + return func(m *NodeMutation) { + var ( + err error + once sync.Once + value *Node + ) + m.oldValue = func(ctx context.Context) (*Node, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Node.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withNode sets the old Node of the mutation. +func withNode(node *Node) nodeOption { + return func(m *NodeMutation) { + m.oldValue = func(context.Context) (*Node, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m NodeMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m NodeMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *NodeMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *NodeMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Node.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *NodeMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *NodeMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *NodeMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *NodeMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *NodeMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *NodeMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *NodeMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *NodeMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *NodeMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *NodeMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[node.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *NodeMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[node.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *NodeMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, node.FieldDeletedAt) +} + +// SetStatus sets the "status" field. +func (m *NodeMutation) SetStatus(n node.Status) { + m.status = &n +} + +// Status returns the value of the "status" field in the mutation. +func (m *NodeMutation) Status() (r node.Status, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldStatus(ctx context.Context) (v node.Status, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *NodeMutation) ResetStatus() { + m.status = nil +} + +// SetName sets the "name" field. +func (m *NodeMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *NodeMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *NodeMutation) ResetName() { + m.name = nil +} + +// SetType sets the "type" field. +func (m *NodeMutation) SetType(n node.Type) { + m._type = &n +} + +// GetType returns the value of the "type" field in the mutation. 
+func (m *NodeMutation) GetType() (r node.Type, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldType(ctx context.Context) (v node.Type, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// ResetType resets all changes to the "type" field. +func (m *NodeMutation) ResetType() { + m._type = nil +} + +// SetServer sets the "server" field. +func (m *NodeMutation) SetServer(s string) { + m.server = &s +} + +// Server returns the value of the "server" field in the mutation. +func (m *NodeMutation) Server() (r string, exists bool) { + v := m.server + if v == nil { + return + } + return *v, true +} + +// OldServer returns the old "server" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldServer(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldServer is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldServer requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldServer: %w", err) + } + return oldValue.Server, nil +} + +// ClearServer clears the value of the "server" field. +func (m *NodeMutation) ClearServer() { + m.server = nil + m.clearedFields[node.FieldServer] = struct{}{} +} + +// ServerCleared returns if the "server" field was cleared in this mutation. +func (m *NodeMutation) ServerCleared() bool { + _, ok := m.clearedFields[node.FieldServer] + return ok +} + +// ResetServer resets all changes to the "server" field. +func (m *NodeMutation) ResetServer() { + m.server = nil + delete(m.clearedFields, node.FieldServer) +} + +// SetSlaveKey sets the "slave_key" field. +func (m *NodeMutation) SetSlaveKey(s string) { + m.slave_key = &s +} + +// SlaveKey returns the value of the "slave_key" field in the mutation. +func (m *NodeMutation) SlaveKey() (r string, exists bool) { + v := m.slave_key + if v == nil { + return + } + return *v, true +} + +// OldSlaveKey returns the old "slave_key" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *NodeMutation) OldSlaveKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSlaveKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSlaveKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSlaveKey: %w", err) + } + return oldValue.SlaveKey, nil +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (m *NodeMutation) ClearSlaveKey() { + m.slave_key = nil + m.clearedFields[node.FieldSlaveKey] = struct{}{} +} + +// SlaveKeyCleared returns if the "slave_key" field was cleared in this mutation. +func (m *NodeMutation) SlaveKeyCleared() bool { + _, ok := m.clearedFields[node.FieldSlaveKey] + return ok +} + +// ResetSlaveKey resets all changes to the "slave_key" field. +func (m *NodeMutation) ResetSlaveKey() { + m.slave_key = nil + delete(m.clearedFields, node.FieldSlaveKey) +} + +// SetCapabilities sets the "capabilities" field. +func (m *NodeMutation) SetCapabilities(bs *boolset.BooleanSet) { + m.capabilities = &bs +} + +// Capabilities returns the value of the "capabilities" field in the mutation. +func (m *NodeMutation) Capabilities() (r *boolset.BooleanSet, exists bool) { + v := m.capabilities + if v == nil { + return + } + return *v, true +} + +// OldCapabilities returns the old "capabilities" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldCapabilities(ctx context.Context) (v *boolset.BooleanSet, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCapabilities is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCapabilities requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCapabilities: %w", err) + } + return oldValue.Capabilities, nil +} + +// ResetCapabilities resets all changes to the "capabilities" field. +func (m *NodeMutation) ResetCapabilities() { + m.capabilities = nil +} + +// SetSettings sets the "settings" field. +func (m *NodeMutation) SetSettings(ts *types.NodeSetting) { + m.settings = &ts +} + +// Settings returns the value of the "settings" field in the mutation. +func (m *NodeMutation) Settings() (r *types.NodeSetting, exists bool) { + v := m.settings + if v == nil { + return + } + return *v, true +} + +// OldSettings returns the old "settings" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldSettings(ctx context.Context) (v *types.NodeSetting, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSettings is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSettings requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSettings: %w", err) + } + return oldValue.Settings, nil +} + +// ClearSettings clears the value of the "settings" field. 
+func (m *NodeMutation) ClearSettings() { + m.settings = nil + m.clearedFields[node.FieldSettings] = struct{}{} +} + +// SettingsCleared returns if the "settings" field was cleared in this mutation. +func (m *NodeMutation) SettingsCleared() bool { + _, ok := m.clearedFields[node.FieldSettings] + return ok +} + +// ResetSettings resets all changes to the "settings" field. +func (m *NodeMutation) ResetSettings() { + m.settings = nil + delete(m.clearedFields, node.FieldSettings) +} + +// SetWeight sets the "weight" field. +func (m *NodeMutation) SetWeight(i int) { + m.weight = &i + m.addweight = nil +} + +// Weight returns the value of the "weight" field in the mutation. +func (m *NodeMutation) Weight() (r int, exists bool) { + v := m.weight + if v == nil { + return + } + return *v, true +} + +// OldWeight returns the old "weight" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldWeight(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeight is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeight requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeight: %w", err) + } + return oldValue.Weight, nil +} + +// AddWeight adds i to the "weight" field. +func (m *NodeMutation) AddWeight(i int) { + if m.addweight != nil { + *m.addweight += i + } else { + m.addweight = &i + } +} + +// AddedWeight returns the value that was added to the "weight" field in this mutation. +func (m *NodeMutation) AddedWeight() (r int, exists bool) { + v := m.addweight + if v == nil { + return + } + return *v, true +} + +// ResetWeight resets all changes to the "weight" field. +func (m *NodeMutation) ResetWeight() { + m.weight = nil + m.addweight = nil +} + +// AddStoragePolicyIDs adds the "storage_policy" edge to the StoragePolicy entity by ids. +func (m *NodeMutation) AddStoragePolicyIDs(ids ...int) { + if m.storage_policy == nil { + m.storage_policy = make(map[int]struct{}) + } + for i := range ids { + m.storage_policy[ids[i]] = struct{}{} + } +} + +// ClearStoragePolicy clears the "storage_policy" edge to the StoragePolicy entity. +func (m *NodeMutation) ClearStoragePolicy() { + m.clearedstorage_policy = true +} + +// StoragePolicyCleared reports if the "storage_policy" edge to the StoragePolicy entity was cleared. +func (m *NodeMutation) StoragePolicyCleared() bool { + return m.clearedstorage_policy +} + +// RemoveStoragePolicyIDs removes the "storage_policy" edge to the StoragePolicy entity by IDs. +func (m *NodeMutation) RemoveStoragePolicyIDs(ids ...int) { + if m.removedstorage_policy == nil { + m.removedstorage_policy = make(map[int]struct{}) + } + for i := range ids { + delete(m.storage_policy, ids[i]) + m.removedstorage_policy[ids[i]] = struct{}{} + } +} + +// RemovedStoragePolicy returns the removed IDs of the "storage_policy" edge to the StoragePolicy entity. +func (m *NodeMutation) RemovedStoragePolicyIDs() (ids []int) { + for id := range m.removedstorage_policy { + ids = append(ids, id) + } + return +} + +// StoragePolicyIDs returns the "storage_policy" edge IDs in the mutation. 
+func (m *NodeMutation) StoragePolicyIDs() (ids []int) { + for id := range m.storage_policy { + ids = append(ids, id) + } + return +} + +// ResetStoragePolicy resets all changes to the "storage_policy" edge. +func (m *NodeMutation) ResetStoragePolicy() { + m.storage_policy = nil + m.clearedstorage_policy = false + m.removedstorage_policy = nil +} + +// Where appends a list predicates to the NodeMutation builder. +func (m *NodeMutation) Where(ps ...predicate.Node) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the NodeMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *NodeMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Node, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *NodeMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *NodeMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Node). +func (m *NodeMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *NodeMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, node.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, node.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, node.FieldDeletedAt) + } + if m.status != nil { + fields = append(fields, node.FieldStatus) + } + if m.name != nil { + fields = append(fields, node.FieldName) + } + if m._type != nil { + fields = append(fields, node.FieldType) + } + if m.server != nil { + fields = append(fields, node.FieldServer) + } + if m.slave_key != nil { + fields = append(fields, node.FieldSlaveKey) + } + if m.capabilities != nil { + fields = append(fields, node.FieldCapabilities) + } + if m.settings != nil { + fields = append(fields, node.FieldSettings) + } + if m.weight != nil { + fields = append(fields, node.FieldWeight) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *NodeMutation) Field(name string) (ent.Value, bool) { + switch name { + case node.FieldCreatedAt: + return m.CreatedAt() + case node.FieldUpdatedAt: + return m.UpdatedAt() + case node.FieldDeletedAt: + return m.DeletedAt() + case node.FieldStatus: + return m.Status() + case node.FieldName: + return m.Name() + case node.FieldType: + return m.GetType() + case node.FieldServer: + return m.Server() + case node.FieldSlaveKey: + return m.SlaveKey() + case node.FieldCapabilities: + return m.Capabilities() + case node.FieldSettings: + return m.Settings() + case node.FieldWeight: + return m.Weight() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *NodeMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case node.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case node.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case node.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case node.FieldStatus: + return m.OldStatus(ctx) + case node.FieldName: + return m.OldName(ctx) + case node.FieldType: + return m.OldType(ctx) + case node.FieldServer: + return m.OldServer(ctx) + case node.FieldSlaveKey: + return m.OldSlaveKey(ctx) + case node.FieldCapabilities: + return m.OldCapabilities(ctx) + case node.FieldSettings: + return m.OldSettings(ctx) + case node.FieldWeight: + return m.OldWeight(ctx) + } + return nil, fmt.Errorf("unknown Node field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *NodeMutation) SetField(name string, value ent.Value) error { + switch name { + case node.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case node.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case node.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case node.FieldStatus: + v, ok := value.(node.Status) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case node.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case node.FieldType: + v, ok := value.(node.Type) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case node.FieldServer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetServer(v) + return nil + case node.FieldSlaveKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSlaveKey(v) + return nil + case node.FieldCapabilities: + v, ok := value.(*boolset.BooleanSet) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCapabilities(v) + return nil + case node.FieldSettings: + v, ok := value.(*types.NodeSetting) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSettings(v) + return nil + case node.FieldWeight: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeight(v) + return nil + } + return fmt.Errorf("unknown Node field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *NodeMutation) AddedFields() []string { + var fields []string + if m.addweight != nil { + fields = append(fields, node.FieldWeight) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. 
+func (m *NodeMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case node.FieldWeight: + return m.AddedWeight() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *NodeMutation) AddField(name string, value ent.Value) error { + switch name { + case node.FieldWeight: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeight(v) + return nil + } + return fmt.Errorf("unknown Node numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *NodeMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(node.FieldDeletedAt) { + fields = append(fields, node.FieldDeletedAt) + } + if m.FieldCleared(node.FieldServer) { + fields = append(fields, node.FieldServer) + } + if m.FieldCleared(node.FieldSlaveKey) { + fields = append(fields, node.FieldSlaveKey) + } + if m.FieldCleared(node.FieldSettings) { + fields = append(fields, node.FieldSettings) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *NodeMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *NodeMutation) ClearField(name string) error { + switch name { + case node.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case node.FieldServer: + m.ClearServer() + return nil + case node.FieldSlaveKey: + m.ClearSlaveKey() + return nil + case node.FieldSettings: + m.ClearSettings() + return nil + } + return fmt.Errorf("unknown Node nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *NodeMutation) ResetField(name string) error { + switch name { + case node.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case node.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case node.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case node.FieldStatus: + m.ResetStatus() + return nil + case node.FieldName: + m.ResetName() + return nil + case node.FieldType: + m.ResetType() + return nil + case node.FieldServer: + m.ResetServer() + return nil + case node.FieldSlaveKey: + m.ResetSlaveKey() + return nil + case node.FieldCapabilities: + m.ResetCapabilities() + return nil + case node.FieldSettings: + m.ResetSettings() + return nil + case node.FieldWeight: + m.ResetWeight() + return nil + } + return fmt.Errorf("unknown Node field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *NodeMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.storage_policy != nil { + edges = append(edges, node.EdgeStoragePolicy) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *NodeMutation) AddedIDs(name string) []ent.Value { + switch name { + case node.EdgeStoragePolicy: + ids := make([]ent.Value, 0, len(m.storage_policy)) + for id := range m.storage_policy { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *NodeMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedstorage_policy != nil { + edges = append(edges, node.EdgeStoragePolicy) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *NodeMutation) RemovedIDs(name string) []ent.Value { + switch name { + case node.EdgeStoragePolicy: + ids := make([]ent.Value, 0, len(m.removedstorage_policy)) + for id := range m.removedstorage_policy { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *NodeMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedstorage_policy { + edges = append(edges, node.EdgeStoragePolicy) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *NodeMutation) EdgeCleared(name string) bool { + switch name { + case node.EdgeStoragePolicy: + return m.clearedstorage_policy + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *NodeMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Node unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *NodeMutation) ResetEdge(name string) error { + switch name { + case node.EdgeStoragePolicy: + m.ResetStoragePolicy() + return nil + } + return fmt.Errorf("unknown Node edge %s", name) +} + +// PasskeyMutation represents an operation that mutates the Passkey nodes in the graph. +type PasskeyMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + credential_id *string + name *string + credential **webauthn.Credential + used_at *time.Time + clearedFields map[string]struct{} + user *int + cleareduser bool + done bool + oldValue func(context.Context) (*Passkey, error) + predicates []predicate.Passkey +} + +var _ ent.Mutation = (*PasskeyMutation)(nil) + +// passkeyOption allows management of the mutation configuration using functional options. +type passkeyOption func(*PasskeyMutation) + +// newPasskeyMutation creates new mutation for the Passkey entity. +func newPasskeyMutation(c config, op Op, opts ...passkeyOption) *PasskeyMutation { + m := &PasskeyMutation{ + config: c, + op: op, + typ: TypePasskey, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPasskeyID sets the ID field of the mutation. 
+func withPasskeyID(id int) passkeyOption { + return func(m *PasskeyMutation) { + var ( + err error + once sync.Once + value *Passkey + ) + m.oldValue = func(ctx context.Context) (*Passkey, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Passkey.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPasskey sets the old Passkey of the mutation. +func withPasskey(node *Passkey) passkeyOption { + return func(m *PasskeyMutation) { + m.oldValue = func(context.Context) (*Passkey, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PasskeyMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PasskeyMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PasskeyMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PasskeyMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Passkey.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *PasskeyMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PasskeyMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PasskeyMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (m *PasskeyMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PasskeyMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PasskeyMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *PasskeyMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *PasskeyMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *PasskeyMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[passkey.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *PasskeyMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[passkey.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *PasskeyMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, passkey.FieldDeletedAt) +} + +// SetUserID sets the "user_id" field. +func (m *PasskeyMutation) SetUserID(i int) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *PasskeyMutation) UserID() (r int, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PasskeyMutation) OldUserID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *PasskeyMutation) ResetUserID() { + m.user = nil +} + +// SetCredentialID sets the "credential_id" field. +func (m *PasskeyMutation) SetCredentialID(s string) { + m.credential_id = &s +} + +// CredentialID returns the value of the "credential_id" field in the mutation. +func (m *PasskeyMutation) CredentialID() (r string, exists bool) { + v := m.credential_id + if v == nil { + return + } + return *v, true +} + +// OldCredentialID returns the old "credential_id" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldCredentialID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCredentialID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCredentialID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCredentialID: %w", err) + } + return oldValue.CredentialID, nil +} + +// ResetCredentialID resets all changes to the "credential_id" field. +func (m *PasskeyMutation) ResetCredentialID() { + m.credential_id = nil +} + +// SetName sets the "name" field. +func (m *PasskeyMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *PasskeyMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *PasskeyMutation) ResetName() { + m.name = nil +} + +// SetCredential sets the "credential" field. +func (m *PasskeyMutation) SetCredential(w *webauthn.Credential) { + m.credential = &w +} + +// Credential returns the value of the "credential" field in the mutation. +func (m *PasskeyMutation) Credential() (r *webauthn.Credential, exists bool) { + v := m.credential + if v == nil { + return + } + return *v, true +} + +// OldCredential returns the old "credential" field's value of the Passkey entity. 
+// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldCredential(ctx context.Context) (v *webauthn.Credential, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCredential is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCredential requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCredential: %w", err) + } + return oldValue.Credential, nil +} + +// ResetCredential resets all changes to the "credential" field. +func (m *PasskeyMutation) ResetCredential() { + m.credential = nil +} + +// SetUsedAt sets the "used_at" field. +func (m *PasskeyMutation) SetUsedAt(t time.Time) { + m.used_at = &t +} + +// UsedAt returns the value of the "used_at" field in the mutation. +func (m *PasskeyMutation) UsedAt() (r time.Time, exists bool) { + v := m.used_at + if v == nil { + return + } + return *v, true +} + +// OldUsedAt returns the old "used_at" field's value of the Passkey entity. +// If the Passkey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PasskeyMutation) OldUsedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsedAt: %w", err) + } + return oldValue.UsedAt, nil +} + +// ClearUsedAt clears the value of the "used_at" field. +func (m *PasskeyMutation) ClearUsedAt() { + m.used_at = nil + m.clearedFields[passkey.FieldUsedAt] = struct{}{} +} + +// UsedAtCleared returns if the "used_at" field was cleared in this mutation. +func (m *PasskeyMutation) UsedAtCleared() bool { + _, ok := m.clearedFields[passkey.FieldUsedAt] + return ok +} + +// ResetUsedAt resets all changes to the "used_at" field. +func (m *PasskeyMutation) ResetUsedAt() { + m.used_at = nil + delete(m.clearedFields, passkey.FieldUsedAt) +} + +// ClearUser clears the "user" edge to the User entity. +func (m *PasskeyMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[passkey.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *PasskeyMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *PasskeyMutation) UserIDs() (ids []int) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *PasskeyMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the PasskeyMutation builder. +func (m *PasskeyMutation) Where(ps ...predicate.Passkey) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PasskeyMutation builder. 
Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PasskeyMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Passkey, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PasskeyMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PasskeyMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Passkey). +func (m *PasskeyMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PasskeyMutation) Fields() []string { + fields := make([]string, 0, 8) + if m.created_at != nil { + fields = append(fields, passkey.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, passkey.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, passkey.FieldDeletedAt) + } + if m.user != nil { + fields = append(fields, passkey.FieldUserID) + } + if m.credential_id != nil { + fields = append(fields, passkey.FieldCredentialID) + } + if m.name != nil { + fields = append(fields, passkey.FieldName) + } + if m.credential != nil { + fields = append(fields, passkey.FieldCredential) + } + if m.used_at != nil { + fields = append(fields, passkey.FieldUsedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PasskeyMutation) Field(name string) (ent.Value, bool) { + switch name { + case passkey.FieldCreatedAt: + return m.CreatedAt() + case passkey.FieldUpdatedAt: + return m.UpdatedAt() + case passkey.FieldDeletedAt: + return m.DeletedAt() + case passkey.FieldUserID: + return m.UserID() + case passkey.FieldCredentialID: + return m.CredentialID() + case passkey.FieldName: + return m.Name() + case passkey.FieldCredential: + return m.Credential() + case passkey.FieldUsedAt: + return m.UsedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PasskeyMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case passkey.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case passkey.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case passkey.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case passkey.FieldUserID: + return m.OldUserID(ctx) + case passkey.FieldCredentialID: + return m.OldCredentialID(ctx) + case passkey.FieldName: + return m.OldName(ctx) + case passkey.FieldCredential: + return m.OldCredential(ctx) + case passkey.FieldUsedAt: + return m.OldUsedAt(ctx) + } + return nil, fmt.Errorf("unknown Passkey field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *PasskeyMutation) SetField(name string, value ent.Value) error { + switch name { + case passkey.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case passkey.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case passkey.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case passkey.FieldUserID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case passkey.FieldCredentialID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCredentialID(v) + return nil + case passkey.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case passkey.FieldCredential: + v, ok := value.(*webauthn.Credential) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCredential(v) + return nil + case passkey.FieldUsedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsedAt(v) + return nil + } + return fmt.Errorf("unknown Passkey field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PasskeyMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PasskeyMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PasskeyMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Passkey numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PasskeyMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(passkey.FieldDeletedAt) { + fields = append(fields, passkey.FieldDeletedAt) + } + if m.FieldCleared(passkey.FieldUsedAt) { + fields = append(fields, passkey.FieldUsedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PasskeyMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PasskeyMutation) ClearField(name string) error { + switch name { + case passkey.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case passkey.FieldUsedAt: + m.ClearUsedAt() + return nil + } + return fmt.Errorf("unknown Passkey nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. 
+// It returns an error if the field is not defined in the schema. +func (m *PasskeyMutation) ResetField(name string) error { + switch name { + case passkey.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case passkey.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case passkey.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case passkey.FieldUserID: + m.ResetUserID() + return nil + case passkey.FieldCredentialID: + m.ResetCredentialID() + return nil + case passkey.FieldName: + m.ResetName() + return nil + case passkey.FieldCredential: + m.ResetCredential() + return nil + case passkey.FieldUsedAt: + m.ResetUsedAt() + return nil + } + return fmt.Errorf("unknown Passkey field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PasskeyMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, passkey.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PasskeyMutation) AddedIDs(name string) []ent.Value { + switch name { + case passkey.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PasskeyMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PasskeyMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PasskeyMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, passkey.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PasskeyMutation) EdgeCleared(name string) bool { + switch name { + case passkey.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PasskeyMutation) ClearEdge(name string) error { + switch name { + case passkey.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown Passkey unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PasskeyMutation) ResetEdge(name string) error { + switch name { + case passkey.EdgeUser: + m.ResetUser() + return nil + } + return fmt.Errorf("unknown Passkey edge %s", name) +} + +// SettingMutation represents an operation that mutates the Setting nodes in the graph. +type SettingMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + value *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*Setting, error) + predicates []predicate.Setting +} + +var _ ent.Mutation = (*SettingMutation)(nil) + +// settingOption allows management of the mutation configuration using functional options. +type settingOption func(*SettingMutation) + +// newSettingMutation creates new mutation for the Setting entity. 
+func newSettingMutation(c config, op Op, opts ...settingOption) *SettingMutation { + m := &SettingMutation{ + config: c, + op: op, + typ: TypeSetting, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSettingID sets the ID field of the mutation. +func withSettingID(id int) settingOption { + return func(m *SettingMutation) { + var ( + err error + once sync.Once + value *Setting + ) + m.oldValue = func(ctx context.Context) (*Setting, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Setting.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSetting sets the old Setting of the mutation. +func withSetting(node *Setting) settingOption { + return func(m *SettingMutation) { + m.oldValue = func(context.Context) (*Setting, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m SettingMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SettingMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *SettingMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SettingMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Setting.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *SettingMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SettingMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Setting entity. +// If the Setting object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SettingMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SettingMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *SettingMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *SettingMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Setting entity. +// If the Setting object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SettingMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *SettingMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *SettingMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *SettingMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Setting entity. +// If the Setting object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SettingMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *SettingMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[setting.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *SettingMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[setting.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. 
+func (m *SettingMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, setting.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *SettingMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *SettingMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Setting entity. +// If the Setting object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SettingMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *SettingMutation) ResetName() { + m.name = nil +} + +// SetValue sets the "value" field. +func (m *SettingMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. +func (m *SettingMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the Setting entity. +// If the Setting object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SettingMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ClearValue clears the value of the "value" field. +func (m *SettingMutation) ClearValue() { + m.value = nil + m.clearedFields[setting.FieldValue] = struct{}{} +} + +// ValueCleared returns if the "value" field was cleared in this mutation. +func (m *SettingMutation) ValueCleared() bool { + _, ok := m.clearedFields[setting.FieldValue] + return ok +} + +// ResetValue resets all changes to the "value" field. +func (m *SettingMutation) ResetValue() { + m.value = nil + delete(m.clearedFields, setting.FieldValue) +} + +// Where appends a list predicates to the SettingMutation builder. +func (m *SettingMutation) Where(ps ...predicate.Setting) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SettingMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SettingMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Setting, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SettingMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. 
+func (m *SettingMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Setting). +func (m *SettingMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SettingMutation) Fields() []string { + fields := make([]string, 0, 5) + if m.created_at != nil { + fields = append(fields, setting.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, setting.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, setting.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, setting.FieldName) + } + if m.value != nil { + fields = append(fields, setting.FieldValue) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *SettingMutation) Field(name string) (ent.Value, bool) { + switch name { + case setting.FieldCreatedAt: + return m.CreatedAt() + case setting.FieldUpdatedAt: + return m.UpdatedAt() + case setting.FieldDeletedAt: + return m.DeletedAt() + case setting.FieldName: + return m.Name() + case setting.FieldValue: + return m.Value() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SettingMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case setting.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case setting.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case setting.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case setting.FieldName: + return m.OldName(ctx) + case setting.FieldValue: + return m.OldValue(ctx) + } + return nil, fmt.Errorf("unknown Setting field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SettingMutation) SetField(name string, value ent.Value) error { + switch name { + case setting.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case setting.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case setting.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case setting.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case setting.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + } + return fmt.Errorf("unknown Setting field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *SettingMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. 
The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SettingMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SettingMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Setting numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *SettingMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(setting.FieldDeletedAt) { + fields = append(fields, setting.FieldDeletedAt) + } + if m.FieldCleared(setting.FieldValue) { + fields = append(fields, setting.FieldValue) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SettingMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SettingMutation) ClearField(name string) error { + switch name { + case setting.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case setting.FieldValue: + m.ClearValue() + return nil + } + return fmt.Errorf("unknown Setting nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *SettingMutation) ResetField(name string) error { + switch name { + case setting.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case setting.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case setting.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case setting.FieldName: + m.ResetName() + return nil + case setting.FieldValue: + m.ResetValue() + return nil + } + return fmt.Errorf("unknown Setting field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SettingMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SettingMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SettingMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SettingMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SettingMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SettingMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. 
+func (m *SettingMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown Setting unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SettingMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown Setting edge %s", name) +} + +// ShareMutation represents an operation that mutates the Share nodes in the graph. +type ShareMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + password *string + views *int + addviews *int + downloads *int + adddownloads *int + expires *time.Time + remain_downloads *int + addremain_downloads *int + clearedFields map[string]struct{} + user *int + cleareduser bool + file *int + clearedfile bool + done bool + oldValue func(context.Context) (*Share, error) + predicates []predicate.Share +} + +var _ ent.Mutation = (*ShareMutation)(nil) + +// shareOption allows management of the mutation configuration using functional options. +type shareOption func(*ShareMutation) + +// newShareMutation creates new mutation for the Share entity. +func newShareMutation(c config, op Op, opts ...shareOption) *ShareMutation { + m := &ShareMutation{ + config: c, + op: op, + typ: TypeShare, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withShareID sets the ID field of the mutation. +func withShareID(id int) shareOption { + return func(m *ShareMutation) { + var ( + err error + once sync.Once + value *Share + ) + m.oldValue = func(ctx context.Context) (*Share, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Share.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withShare sets the old Share of the mutation. +func withShare(node *Share) shareOption { + return func(m *ShareMutation) { + m.oldValue = func(context.Context) (*Share, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ShareMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m ShareMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ShareMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *ShareMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Share.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *ShareMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ShareMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ShareMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ShareMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *ShareMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ShareMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ShareMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ShareMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *ShareMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *ShareMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *ShareMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *ShareMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[share.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *ShareMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[share.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *ShareMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, share.FieldDeletedAt) +} + +// SetPassword sets the "password" field. +func (m *ShareMutation) SetPassword(s string) { + m.password = &s +} + +// Password returns the value of the "password" field in the mutation. +func (m *ShareMutation) Password() (r string, exists bool) { + v := m.password + if v == nil { + return + } + return *v, true +} + +// OldPassword returns the old "password" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ShareMutation) OldPassword(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassword is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassword requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassword: %w", err) + } + return oldValue.Password, nil +} + +// ClearPassword clears the value of the "password" field. +func (m *ShareMutation) ClearPassword() { + m.password = nil + m.clearedFields[share.FieldPassword] = struct{}{} +} + +// PasswordCleared returns if the "password" field was cleared in this mutation. +func (m *ShareMutation) PasswordCleared() bool { + _, ok := m.clearedFields[share.FieldPassword] + return ok +} + +// ResetPassword resets all changes to the "password" field. +func (m *ShareMutation) ResetPassword() { + m.password = nil + delete(m.clearedFields, share.FieldPassword) +} + +// SetViews sets the "views" field. +func (m *ShareMutation) SetViews(i int) { + m.views = &i + m.addviews = nil +} + +// Views returns the value of the "views" field in the mutation. +func (m *ShareMutation) Views() (r int, exists bool) { + v := m.views + if v == nil { + return + } + return *v, true +} + +// OldViews returns the old "views" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *ShareMutation) OldViews(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldViews is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldViews requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldViews: %w", err) + } + return oldValue.Views, nil +} + +// AddViews adds i to the "views" field. +func (m *ShareMutation) AddViews(i int) { + if m.addviews != nil { + *m.addviews += i + } else { + m.addviews = &i + } +} + +// AddedViews returns the value that was added to the "views" field in this mutation. +func (m *ShareMutation) AddedViews() (r int, exists bool) { + v := m.addviews + if v == nil { + return + } + return *v, true +} + +// ResetViews resets all changes to the "views" field. +func (m *ShareMutation) ResetViews() { + m.views = nil + m.addviews = nil +} + +// SetDownloads sets the "downloads" field. +func (m *ShareMutation) SetDownloads(i int) { + m.downloads = &i + m.adddownloads = nil +} + +// Downloads returns the value of the "downloads" field in the mutation. +func (m *ShareMutation) Downloads() (r int, exists bool) { + v := m.downloads + if v == nil { + return + } + return *v, true +} + +// OldDownloads returns the old "downloads" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ShareMutation) OldDownloads(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDownloads is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDownloads requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDownloads: %w", err) + } + return oldValue.Downloads, nil +} + +// AddDownloads adds i to the "downloads" field. +func (m *ShareMutation) AddDownloads(i int) { + if m.adddownloads != nil { + *m.adddownloads += i + } else { + m.adddownloads = &i + } +} + +// AddedDownloads returns the value that was added to the "downloads" field in this mutation. +func (m *ShareMutation) AddedDownloads() (r int, exists bool) { + v := m.adddownloads + if v == nil { + return + } + return *v, true +} + +// ResetDownloads resets all changes to the "downloads" field. +func (m *ShareMutation) ResetDownloads() { + m.downloads = nil + m.adddownloads = nil +} + +// SetExpires sets the "expires" field. +func (m *ShareMutation) SetExpires(t time.Time) { + m.expires = &t +} + +// Expires returns the value of the "expires" field in the mutation. +func (m *ShareMutation) Expires() (r time.Time, exists bool) { + v := m.expires + if v == nil { + return + } + return *v, true +} + +// OldExpires returns the old "expires" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *ShareMutation) OldExpires(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpires is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpires requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpires: %w", err) + } + return oldValue.Expires, nil +} + +// ClearExpires clears the value of the "expires" field. +func (m *ShareMutation) ClearExpires() { + m.expires = nil + m.clearedFields[share.FieldExpires] = struct{}{} +} + +// ExpiresCleared returns if the "expires" field was cleared in this mutation. +func (m *ShareMutation) ExpiresCleared() bool { + _, ok := m.clearedFields[share.FieldExpires] + return ok +} + +// ResetExpires resets all changes to the "expires" field. +func (m *ShareMutation) ResetExpires() { + m.expires = nil + delete(m.clearedFields, share.FieldExpires) +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (m *ShareMutation) SetRemainDownloads(i int) { + m.remain_downloads = &i + m.addremain_downloads = nil +} + +// RemainDownloads returns the value of the "remain_downloads" field in the mutation. +func (m *ShareMutation) RemainDownloads() (r int, exists bool) { + v := m.remain_downloads + if v == nil { + return + } + return *v, true +} + +// OldRemainDownloads returns the old "remain_downloads" field's value of the Share entity. +// If the Share object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ShareMutation) OldRemainDownloads(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRemainDownloads is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRemainDownloads requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRemainDownloads: %w", err) + } + return oldValue.RemainDownloads, nil +} + +// AddRemainDownloads adds i to the "remain_downloads" field. +func (m *ShareMutation) AddRemainDownloads(i int) { + if m.addremain_downloads != nil { + *m.addremain_downloads += i + } else { + m.addremain_downloads = &i + } +} + +// AddedRemainDownloads returns the value that was added to the "remain_downloads" field in this mutation. +func (m *ShareMutation) AddedRemainDownloads() (r int, exists bool) { + v := m.addremain_downloads + if v == nil { + return + } + return *v, true +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (m *ShareMutation) ClearRemainDownloads() { + m.remain_downloads = nil + m.addremain_downloads = nil + m.clearedFields[share.FieldRemainDownloads] = struct{}{} +} + +// RemainDownloadsCleared returns if the "remain_downloads" field was cleared in this mutation. +func (m *ShareMutation) RemainDownloadsCleared() bool { + _, ok := m.clearedFields[share.FieldRemainDownloads] + return ok +} + +// ResetRemainDownloads resets all changes to the "remain_downloads" field. +func (m *ShareMutation) ResetRemainDownloads() { + m.remain_downloads = nil + m.addremain_downloads = nil + delete(m.clearedFields, share.FieldRemainDownloads) +} + +// SetUserID sets the "user" edge to the User entity by id. 
+func (m *ShareMutation) SetUserID(id int) { + m.user = &id +} + +// ClearUser clears the "user" edge to the User entity. +func (m *ShareMutation) ClearUser() { + m.cleareduser = true +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *ShareMutation) UserCleared() bool { + return m.cleareduser +} + +// UserID returns the "user" edge ID in the mutation. +func (m *ShareMutation) UserID() (id int, exists bool) { + if m.user != nil { + return *m.user, true + } + return +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *ShareMutation) UserIDs() (ids []int) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *ShareMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// SetFileID sets the "file" edge to the File entity by id. +func (m *ShareMutation) SetFileID(id int) { + m.file = &id +} + +// ClearFile clears the "file" edge to the File entity. +func (m *ShareMutation) ClearFile() { + m.clearedfile = true +} + +// FileCleared reports if the "file" edge to the File entity was cleared. +func (m *ShareMutation) FileCleared() bool { + return m.clearedfile +} + +// FileID returns the "file" edge ID in the mutation. +func (m *ShareMutation) FileID() (id int, exists bool) { + if m.file != nil { + return *m.file, true + } + return +} + +// FileIDs returns the "file" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// FileID instead. It exists only for internal usage by the builders. +func (m *ShareMutation) FileIDs() (ids []int) { + if id := m.file; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetFile resets all changes to the "file" edge. +func (m *ShareMutation) ResetFile() { + m.file = nil + m.clearedfile = false +} + +// Where appends a list predicates to the ShareMutation builder. +func (m *ShareMutation) Where(ps ...predicate.Share) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the ShareMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ShareMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Share, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *ShareMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *ShareMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Share). +func (m *ShareMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *ShareMutation) Fields() []string { + fields := make([]string, 0, 8) + if m.created_at != nil { + fields = append(fields, share.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, share.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, share.FieldDeletedAt) + } + if m.password != nil { + fields = append(fields, share.FieldPassword) + } + if m.views != nil { + fields = append(fields, share.FieldViews) + } + if m.downloads != nil { + fields = append(fields, share.FieldDownloads) + } + if m.expires != nil { + fields = append(fields, share.FieldExpires) + } + if m.remain_downloads != nil { + fields = append(fields, share.FieldRemainDownloads) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ShareMutation) Field(name string) (ent.Value, bool) { + switch name { + case share.FieldCreatedAt: + return m.CreatedAt() + case share.FieldUpdatedAt: + return m.UpdatedAt() + case share.FieldDeletedAt: + return m.DeletedAt() + case share.FieldPassword: + return m.Password() + case share.FieldViews: + return m.Views() + case share.FieldDownloads: + return m.Downloads() + case share.FieldExpires: + return m.Expires() + case share.FieldRemainDownloads: + return m.RemainDownloads() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ShareMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case share.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case share.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case share.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case share.FieldPassword: + return m.OldPassword(ctx) + case share.FieldViews: + return m.OldViews(ctx) + case share.FieldDownloads: + return m.OldDownloads(ctx) + case share.FieldExpires: + return m.OldExpires(ctx) + case share.FieldRemainDownloads: + return m.OldRemainDownloads(ctx) + } + return nil, fmt.Errorf("unknown Share field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *ShareMutation) SetField(name string, value ent.Value) error { + switch name { + case share.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case share.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case share.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case share.FieldPassword: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassword(v) + return nil + case share.FieldViews: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetViews(v) + return nil + case share.FieldDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDownloads(v) + return nil + case share.FieldExpires: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpires(v) + return nil + case share.FieldRemainDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRemainDownloads(v) + return nil + } + return fmt.Errorf("unknown Share field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ShareMutation) AddedFields() []string { + var fields []string + if m.addviews != nil { + fields = append(fields, share.FieldViews) + } + if m.adddownloads != nil { + fields = append(fields, share.FieldDownloads) + } + if m.addremain_downloads != nil { + fields = append(fields, share.FieldRemainDownloads) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ShareMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case share.FieldViews: + return m.AddedViews() + case share.FieldDownloads: + return m.AddedDownloads() + case share.FieldRemainDownloads: + return m.AddedRemainDownloads() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ShareMutation) AddField(name string, value ent.Value) error { + switch name { + case share.FieldViews: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddViews(v) + return nil + case share.FieldDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDownloads(v) + return nil + case share.FieldRemainDownloads: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRemainDownloads(v) + return nil + } + return fmt.Errorf("unknown Share numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *ShareMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(share.FieldDeletedAt) { + fields = append(fields, share.FieldDeletedAt) + } + if m.FieldCleared(share.FieldPassword) { + fields = append(fields, share.FieldPassword) + } + if m.FieldCleared(share.FieldExpires) { + fields = append(fields, share.FieldExpires) + } + if m.FieldCleared(share.FieldRemainDownloads) { + fields = append(fields, share.FieldRemainDownloads) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ShareMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ShareMutation) ClearField(name string) error { + switch name { + case share.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case share.FieldPassword: + m.ClearPassword() + return nil + case share.FieldExpires: + m.ClearExpires() + return nil + case share.FieldRemainDownloads: + m.ClearRemainDownloads() + return nil + } + return fmt.Errorf("unknown Share nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ShareMutation) ResetField(name string) error { + switch name { + case share.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case share.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case share.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case share.FieldPassword: + m.ResetPassword() + return nil + case share.FieldViews: + m.ResetViews() + return nil + case share.FieldDownloads: + m.ResetDownloads() + return nil + case share.FieldExpires: + m.ResetExpires() + return nil + case share.FieldRemainDownloads: + m.ResetRemainDownloads() + return nil + } + return fmt.Errorf("unknown Share field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ShareMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.user != nil { + edges = append(edges, share.EdgeUser) + } + if m.file != nil { + edges = append(edges, share.EdgeFile) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ShareMutation) AddedIDs(name string) []ent.Value { + switch name { + case share.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case share.EdgeFile: + if id := m.file; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ShareMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ShareMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ShareMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.cleareduser { + edges = append(edges, share.EdgeUser) + } + if m.clearedfile { + edges = append(edges, share.EdgeFile) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. 
+func (m *ShareMutation) EdgeCleared(name string) bool { + switch name { + case share.EdgeUser: + return m.cleareduser + case share.EdgeFile: + return m.clearedfile + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ShareMutation) ClearEdge(name string) error { + switch name { + case share.EdgeUser: + m.ClearUser() + return nil + case share.EdgeFile: + m.ClearFile() + return nil + } + return fmt.Errorf("unknown Share unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ShareMutation) ResetEdge(name string) error { + switch name { + case share.EdgeUser: + m.ResetUser() + return nil + case share.EdgeFile: + m.ResetFile() + return nil + } + return fmt.Errorf("unknown Share edge %s", name) +} + +// StoragePolicyMutation represents an operation that mutates the StoragePolicy nodes in the graph. +type StoragePolicyMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + _type *string + server *string + bucket_name *string + is_private *bool + access_key *string + secret_key *string + max_size *int64 + addmax_size *int64 + dir_name_rule *string + file_name_rule *string + settings **types.PolicySetting + clearedFields map[string]struct{} + users map[int]struct{} + removedusers map[int]struct{} + clearedusers bool + groups map[int]struct{} + removedgroups map[int]struct{} + clearedgroups bool + files map[int]struct{} + removedfiles map[int]struct{} + clearedfiles bool + entities map[int]struct{} + removedentities map[int]struct{} + clearedentities bool + node *int + clearednode bool + done bool + oldValue func(context.Context) (*StoragePolicy, error) + predicates []predicate.StoragePolicy +} + +var _ ent.Mutation = (*StoragePolicyMutation)(nil) + +// storagepolicyOption allows management of the mutation configuration using functional options. +type storagepolicyOption func(*StoragePolicyMutation) + +// newStoragePolicyMutation creates new mutation for the StoragePolicy entity. +func newStoragePolicyMutation(c config, op Op, opts ...storagepolicyOption) *StoragePolicyMutation { + m := &StoragePolicyMutation{ + config: c, + op: op, + typ: TypeStoragePolicy, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withStoragePolicyID sets the ID field of the mutation. +func withStoragePolicyID(id int) storagepolicyOption { + return func(m *StoragePolicyMutation) { + var ( + err error + once sync.Once + value *StoragePolicy + ) + m.oldValue = func(ctx context.Context) (*StoragePolicy, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().StoragePolicy.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withStoragePolicy sets the old StoragePolicy of the mutation. +func withStoragePolicy(node *StoragePolicy) storagepolicyOption { + return func(m *StoragePolicyMutation) { + m.oldValue = func(context.Context) (*StoragePolicy, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. 
+func (m StoragePolicyMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m StoragePolicyMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *StoragePolicyMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *StoragePolicyMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().StoragePolicy.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *StoragePolicyMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *StoragePolicyMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *StoragePolicyMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *StoragePolicyMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *StoragePolicyMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *StoragePolicyMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *StoragePolicyMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *StoragePolicyMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *StoragePolicyMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *StoragePolicyMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[storagepolicy.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *StoragePolicyMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *StoragePolicyMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, storagepolicy.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *StoragePolicyMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *StoragePolicyMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. 
+func (m *StoragePolicyMutation) ResetName() { + m.name = nil +} + +// SetType sets the "type" field. +func (m *StoragePolicyMutation) SetType(s string) { + m._type = &s +} + +// GetType returns the value of the "type" field in the mutation. +func (m *StoragePolicyMutation) GetType() (r string, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// ResetType resets all changes to the "type" field. +func (m *StoragePolicyMutation) ResetType() { + m._type = nil +} + +// SetServer sets the "server" field. +func (m *StoragePolicyMutation) SetServer(s string) { + m.server = &s +} + +// Server returns the value of the "server" field in the mutation. +func (m *StoragePolicyMutation) Server() (r string, exists bool) { + v := m.server + if v == nil { + return + } + return *v, true +} + +// OldServer returns the old "server" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldServer(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldServer is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldServer requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldServer: %w", err) + } + return oldValue.Server, nil +} + +// ClearServer clears the value of the "server" field. +func (m *StoragePolicyMutation) ClearServer() { + m.server = nil + m.clearedFields[storagepolicy.FieldServer] = struct{}{} +} + +// ServerCleared returns if the "server" field was cleared in this mutation. +func (m *StoragePolicyMutation) ServerCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldServer] + return ok +} + +// ResetServer resets all changes to the "server" field. +func (m *StoragePolicyMutation) ResetServer() { + m.server = nil + delete(m.clearedFields, storagepolicy.FieldServer) +} + +// SetBucketName sets the "bucket_name" field. +func (m *StoragePolicyMutation) SetBucketName(s string) { + m.bucket_name = &s +} + +// BucketName returns the value of the "bucket_name" field in the mutation. +func (m *StoragePolicyMutation) BucketName() (r string, exists bool) { + v := m.bucket_name + if v == nil { + return + } + return *v, true +} + +// OldBucketName returns the old "bucket_name" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *StoragePolicyMutation) OldBucketName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBucketName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBucketName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBucketName: %w", err) + } + return oldValue.BucketName, nil +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (m *StoragePolicyMutation) ClearBucketName() { + m.bucket_name = nil + m.clearedFields[storagepolicy.FieldBucketName] = struct{}{} +} + +// BucketNameCleared returns if the "bucket_name" field was cleared in this mutation. +func (m *StoragePolicyMutation) BucketNameCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldBucketName] + return ok +} + +// ResetBucketName resets all changes to the "bucket_name" field. +func (m *StoragePolicyMutation) ResetBucketName() { + m.bucket_name = nil + delete(m.clearedFields, storagepolicy.FieldBucketName) +} + +// SetIsPrivate sets the "is_private" field. +func (m *StoragePolicyMutation) SetIsPrivate(b bool) { + m.is_private = &b +} + +// IsPrivate returns the value of the "is_private" field in the mutation. +func (m *StoragePolicyMutation) IsPrivate() (r bool, exists bool) { + v := m.is_private + if v == nil { + return + } + return *v, true +} + +// OldIsPrivate returns the old "is_private" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldIsPrivate(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsPrivate is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsPrivate requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsPrivate: %w", err) + } + return oldValue.IsPrivate, nil +} + +// ClearIsPrivate clears the value of the "is_private" field. +func (m *StoragePolicyMutation) ClearIsPrivate() { + m.is_private = nil + m.clearedFields[storagepolicy.FieldIsPrivate] = struct{}{} +} + +// IsPrivateCleared returns if the "is_private" field was cleared in this mutation. +func (m *StoragePolicyMutation) IsPrivateCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldIsPrivate] + return ok +} + +// ResetIsPrivate resets all changes to the "is_private" field. +func (m *StoragePolicyMutation) ResetIsPrivate() { + m.is_private = nil + delete(m.clearedFields, storagepolicy.FieldIsPrivate) +} + +// SetAccessKey sets the "access_key" field. +func (m *StoragePolicyMutation) SetAccessKey(s string) { + m.access_key = &s +} + +// AccessKey returns the value of the "access_key" field in the mutation. +func (m *StoragePolicyMutation) AccessKey() (r string, exists bool) { + v := m.access_key + if v == nil { + return + } + return *v, true +} + +// OldAccessKey returns the old "access_key" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *StoragePolicyMutation) OldAccessKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccessKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccessKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccessKey: %w", err) + } + return oldValue.AccessKey, nil +} + +// ClearAccessKey clears the value of the "access_key" field. +func (m *StoragePolicyMutation) ClearAccessKey() { + m.access_key = nil + m.clearedFields[storagepolicy.FieldAccessKey] = struct{}{} +} + +// AccessKeyCleared returns if the "access_key" field was cleared in this mutation. +func (m *StoragePolicyMutation) AccessKeyCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldAccessKey] + return ok +} + +// ResetAccessKey resets all changes to the "access_key" field. +func (m *StoragePolicyMutation) ResetAccessKey() { + m.access_key = nil + delete(m.clearedFields, storagepolicy.FieldAccessKey) +} + +// SetSecretKey sets the "secret_key" field. +func (m *StoragePolicyMutation) SetSecretKey(s string) { + m.secret_key = &s +} + +// SecretKey returns the value of the "secret_key" field in the mutation. +func (m *StoragePolicyMutation) SecretKey() (r string, exists bool) { + v := m.secret_key + if v == nil { + return + } + return *v, true +} + +// OldSecretKey returns the old "secret_key" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldSecretKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSecretKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSecretKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSecretKey: %w", err) + } + return oldValue.SecretKey, nil +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (m *StoragePolicyMutation) ClearSecretKey() { + m.secret_key = nil + m.clearedFields[storagepolicy.FieldSecretKey] = struct{}{} +} + +// SecretKeyCleared returns if the "secret_key" field was cleared in this mutation. +func (m *StoragePolicyMutation) SecretKeyCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldSecretKey] + return ok +} + +// ResetSecretKey resets all changes to the "secret_key" field. +func (m *StoragePolicyMutation) ResetSecretKey() { + m.secret_key = nil + delete(m.clearedFields, storagepolicy.FieldSecretKey) +} + +// SetMaxSize sets the "max_size" field. +func (m *StoragePolicyMutation) SetMaxSize(i int64) { + m.max_size = &i + m.addmax_size = nil +} + +// MaxSize returns the value of the "max_size" field in the mutation. +func (m *StoragePolicyMutation) MaxSize() (r int64, exists bool) { + v := m.max_size + if v == nil { + return + } + return *v, true +} + +// OldMaxSize returns the old "max_size" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *StoragePolicyMutation) OldMaxSize(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMaxSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMaxSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMaxSize: %w", err) + } + return oldValue.MaxSize, nil +} + +// AddMaxSize adds i to the "max_size" field. +func (m *StoragePolicyMutation) AddMaxSize(i int64) { + if m.addmax_size != nil { + *m.addmax_size += i + } else { + m.addmax_size = &i + } +} + +// AddedMaxSize returns the value that was added to the "max_size" field in this mutation. +func (m *StoragePolicyMutation) AddedMaxSize() (r int64, exists bool) { + v := m.addmax_size + if v == nil { + return + } + return *v, true +} + +// ClearMaxSize clears the value of the "max_size" field. +func (m *StoragePolicyMutation) ClearMaxSize() { + m.max_size = nil + m.addmax_size = nil + m.clearedFields[storagepolicy.FieldMaxSize] = struct{}{} +} + +// MaxSizeCleared returns if the "max_size" field was cleared in this mutation. +func (m *StoragePolicyMutation) MaxSizeCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldMaxSize] + return ok +} + +// ResetMaxSize resets all changes to the "max_size" field. +func (m *StoragePolicyMutation) ResetMaxSize() { + m.max_size = nil + m.addmax_size = nil + delete(m.clearedFields, storagepolicy.FieldMaxSize) +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (m *StoragePolicyMutation) SetDirNameRule(s string) { + m.dir_name_rule = &s +} + +// DirNameRule returns the value of the "dir_name_rule" field in the mutation. +func (m *StoragePolicyMutation) DirNameRule() (r string, exists bool) { + v := m.dir_name_rule + if v == nil { + return + } + return *v, true +} + +// OldDirNameRule returns the old "dir_name_rule" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldDirNameRule(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDirNameRule is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDirNameRule requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDirNameRule: %w", err) + } + return oldValue.DirNameRule, nil +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (m *StoragePolicyMutation) ClearDirNameRule() { + m.dir_name_rule = nil + m.clearedFields[storagepolicy.FieldDirNameRule] = struct{}{} +} + +// DirNameRuleCleared returns if the "dir_name_rule" field was cleared in this mutation. +func (m *StoragePolicyMutation) DirNameRuleCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldDirNameRule] + return ok +} + +// ResetDirNameRule resets all changes to the "dir_name_rule" field. +func (m *StoragePolicyMutation) ResetDirNameRule() { + m.dir_name_rule = nil + delete(m.clearedFields, storagepolicy.FieldDirNameRule) +} + +// SetFileNameRule sets the "file_name_rule" field. 
+func (m *StoragePolicyMutation) SetFileNameRule(s string) { + m.file_name_rule = &s +} + +// FileNameRule returns the value of the "file_name_rule" field in the mutation. +func (m *StoragePolicyMutation) FileNameRule() (r string, exists bool) { + v := m.file_name_rule + if v == nil { + return + } + return *v, true +} + +// OldFileNameRule returns the old "file_name_rule" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldFileNameRule(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFileNameRule is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFileNameRule requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFileNameRule: %w", err) + } + return oldValue.FileNameRule, nil +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (m *StoragePolicyMutation) ClearFileNameRule() { + m.file_name_rule = nil + m.clearedFields[storagepolicy.FieldFileNameRule] = struct{}{} +} + +// FileNameRuleCleared returns if the "file_name_rule" field was cleared in this mutation. +func (m *StoragePolicyMutation) FileNameRuleCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldFileNameRule] + return ok +} + +// ResetFileNameRule resets all changes to the "file_name_rule" field. +func (m *StoragePolicyMutation) ResetFileNameRule() { + m.file_name_rule = nil + delete(m.clearedFields, storagepolicy.FieldFileNameRule) +} + +// SetSettings sets the "settings" field. +func (m *StoragePolicyMutation) SetSettings(ts *types.PolicySetting) { + m.settings = &ts +} + +// Settings returns the value of the "settings" field in the mutation. +func (m *StoragePolicyMutation) Settings() (r *types.PolicySetting, exists bool) { + v := m.settings + if v == nil { + return + } + return *v, true +} + +// OldSettings returns the old "settings" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldSettings(ctx context.Context) (v *types.PolicySetting, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSettings is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSettings requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSettings: %w", err) + } + return oldValue.Settings, nil +} + +// ClearSettings clears the value of the "settings" field. +func (m *StoragePolicyMutation) ClearSettings() { + m.settings = nil + m.clearedFields[storagepolicy.FieldSettings] = struct{}{} +} + +// SettingsCleared returns if the "settings" field was cleared in this mutation. +func (m *StoragePolicyMutation) SettingsCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldSettings] + return ok +} + +// ResetSettings resets all changes to the "settings" field. 
+func (m *StoragePolicyMutation) ResetSettings() { + m.settings = nil + delete(m.clearedFields, storagepolicy.FieldSettings) +} + +// SetNodeID sets the "node_id" field. +func (m *StoragePolicyMutation) SetNodeID(i int) { + m.node = &i +} + +// NodeID returns the value of the "node_id" field in the mutation. +func (m *StoragePolicyMutation) NodeID() (r int, exists bool) { + v := m.node + if v == nil { + return + } + return *v, true +} + +// OldNodeID returns the old "node_id" field's value of the StoragePolicy entity. +// If the StoragePolicy object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *StoragePolicyMutation) OldNodeID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldNodeID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldNodeID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldNodeID: %w", err) + } + return oldValue.NodeID, nil +} + +// ClearNodeID clears the value of the "node_id" field. +func (m *StoragePolicyMutation) ClearNodeID() { + m.node = nil + m.clearedFields[storagepolicy.FieldNodeID] = struct{}{} +} + +// NodeIDCleared returns if the "node_id" field was cleared in this mutation. +func (m *StoragePolicyMutation) NodeIDCleared() bool { + _, ok := m.clearedFields[storagepolicy.FieldNodeID] + return ok +} + +// ResetNodeID resets all changes to the "node_id" field. +func (m *StoragePolicyMutation) ResetNodeID() { + m.node = nil + delete(m.clearedFields, storagepolicy.FieldNodeID) +} + +// AddUserIDs adds the "users" edge to the User entity by ids. +func (m *StoragePolicyMutation) AddUserIDs(ids ...int) { + if m.users == nil { + m.users = make(map[int]struct{}) + } + for i := range ids { + m.users[ids[i]] = struct{}{} + } +} + +// ClearUsers clears the "users" edge to the User entity. +func (m *StoragePolicyMutation) ClearUsers() { + m.clearedusers = true +} + +// UsersCleared reports if the "users" edge to the User entity was cleared. +func (m *StoragePolicyMutation) UsersCleared() bool { + return m.clearedusers +} + +// RemoveUserIDs removes the "users" edge to the User entity by IDs. +func (m *StoragePolicyMutation) RemoveUserIDs(ids ...int) { + if m.removedusers == nil { + m.removedusers = make(map[int]struct{}) + } + for i := range ids { + delete(m.users, ids[i]) + m.removedusers[ids[i]] = struct{}{} + } +} + +// RemovedUsers returns the removed IDs of the "users" edge to the User entity. +func (m *StoragePolicyMutation) RemovedUsersIDs() (ids []int) { + for id := range m.removedusers { + ids = append(ids, id) + } + return +} + +// UsersIDs returns the "users" edge IDs in the mutation. +func (m *StoragePolicyMutation) UsersIDs() (ids []int) { + for id := range m.users { + ids = append(ids, id) + } + return +} + +// ResetUsers resets all changes to the "users" edge. +func (m *StoragePolicyMutation) ResetUsers() { + m.users = nil + m.clearedusers = false + m.removedusers = nil +} + +// AddGroupIDs adds the "groups" edge to the Group entity by ids. +func (m *StoragePolicyMutation) AddGroupIDs(ids ...int) { + if m.groups == nil { + m.groups = make(map[int]struct{}) + } + for i := range ids { + m.groups[ids[i]] = struct{}{} + } +} + +// ClearGroups clears the "groups" edge to the Group entity. 
+func (m *StoragePolicyMutation) ClearGroups() { + m.clearedgroups = true +} + +// GroupsCleared reports if the "groups" edge to the Group entity was cleared. +func (m *StoragePolicyMutation) GroupsCleared() bool { + return m.clearedgroups +} + +// RemoveGroupIDs removes the "groups" edge to the Group entity by IDs. +func (m *StoragePolicyMutation) RemoveGroupIDs(ids ...int) { + if m.removedgroups == nil { + m.removedgroups = make(map[int]struct{}) + } + for i := range ids { + delete(m.groups, ids[i]) + m.removedgroups[ids[i]] = struct{}{} + } +} + +// RemovedGroups returns the removed IDs of the "groups" edge to the Group entity. +func (m *StoragePolicyMutation) RemovedGroupsIDs() (ids []int) { + for id := range m.removedgroups { + ids = append(ids, id) + } + return +} + +// GroupsIDs returns the "groups" edge IDs in the mutation. +func (m *StoragePolicyMutation) GroupsIDs() (ids []int) { + for id := range m.groups { + ids = append(ids, id) + } + return +} + +// ResetGroups resets all changes to the "groups" edge. +func (m *StoragePolicyMutation) ResetGroups() { + m.groups = nil + m.clearedgroups = false + m.removedgroups = nil +} + +// AddFileIDs adds the "files" edge to the File entity by ids. +func (m *StoragePolicyMutation) AddFileIDs(ids ...int) { + if m.files == nil { + m.files = make(map[int]struct{}) + } + for i := range ids { + m.files[ids[i]] = struct{}{} + } +} + +// ClearFiles clears the "files" edge to the File entity. +func (m *StoragePolicyMutation) ClearFiles() { + m.clearedfiles = true +} + +// FilesCleared reports if the "files" edge to the File entity was cleared. +func (m *StoragePolicyMutation) FilesCleared() bool { + return m.clearedfiles +} + +// RemoveFileIDs removes the "files" edge to the File entity by IDs. +func (m *StoragePolicyMutation) RemoveFileIDs(ids ...int) { + if m.removedfiles == nil { + m.removedfiles = make(map[int]struct{}) + } + for i := range ids { + delete(m.files, ids[i]) + m.removedfiles[ids[i]] = struct{}{} + } +} + +// RemovedFiles returns the removed IDs of the "files" edge to the File entity. +func (m *StoragePolicyMutation) RemovedFilesIDs() (ids []int) { + for id := range m.removedfiles { + ids = append(ids, id) + } + return +} + +// FilesIDs returns the "files" edge IDs in the mutation. +func (m *StoragePolicyMutation) FilesIDs() (ids []int) { + for id := range m.files { + ids = append(ids, id) + } + return +} + +// ResetFiles resets all changes to the "files" edge. +func (m *StoragePolicyMutation) ResetFiles() { + m.files = nil + m.clearedfiles = false + m.removedfiles = nil +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by ids. +func (m *StoragePolicyMutation) AddEntityIDs(ids ...int) { + if m.entities == nil { + m.entities = make(map[int]struct{}) + } + for i := range ids { + m.entities[ids[i]] = struct{}{} + } +} + +// ClearEntities clears the "entities" edge to the Entity entity. +func (m *StoragePolicyMutation) ClearEntities() { + m.clearedentities = true +} + +// EntitiesCleared reports if the "entities" edge to the Entity entity was cleared. +func (m *StoragePolicyMutation) EntitiesCleared() bool { + return m.clearedentities +} + +// RemoveEntityIDs removes the "entities" edge to the Entity entity by IDs. 
+func (m *StoragePolicyMutation) RemoveEntityIDs(ids ...int) { + if m.removedentities == nil { + m.removedentities = make(map[int]struct{}) + } + for i := range ids { + delete(m.entities, ids[i]) + m.removedentities[ids[i]] = struct{}{} + } +} + +// RemovedEntities returns the removed IDs of the "entities" edge to the Entity entity. +func (m *StoragePolicyMutation) RemovedEntitiesIDs() (ids []int) { + for id := range m.removedentities { + ids = append(ids, id) + } + return +} + +// EntitiesIDs returns the "entities" edge IDs in the mutation. +func (m *StoragePolicyMutation) EntitiesIDs() (ids []int) { + for id := range m.entities { + ids = append(ids, id) + } + return +} + +// ResetEntities resets all changes to the "entities" edge. +func (m *StoragePolicyMutation) ResetEntities() { + m.entities = nil + m.clearedentities = false + m.removedentities = nil +} + +// ClearNode clears the "node" edge to the Node entity. +func (m *StoragePolicyMutation) ClearNode() { + m.clearednode = true + m.clearedFields[storagepolicy.FieldNodeID] = struct{}{} +} + +// NodeCleared reports if the "node" edge to the Node entity was cleared. +func (m *StoragePolicyMutation) NodeCleared() bool { + return m.NodeIDCleared() || m.clearednode +} + +// NodeIDs returns the "node" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// NodeID instead. It exists only for internal usage by the builders. +func (m *StoragePolicyMutation) NodeIDs() (ids []int) { + if id := m.node; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetNode resets all changes to the "node" edge. +func (m *StoragePolicyMutation) ResetNode() { + m.node = nil + m.clearednode = false +} + +// Where appends a list predicates to the StoragePolicyMutation builder. +func (m *StoragePolicyMutation) Where(ps ...predicate.StoragePolicy) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the StoragePolicyMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *StoragePolicyMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.StoragePolicy, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *StoragePolicyMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *StoragePolicyMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (StoragePolicy). +func (m *StoragePolicyMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *StoragePolicyMutation) Fields() []string { + fields := make([]string, 0, 15) + if m.created_at != nil { + fields = append(fields, storagepolicy.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, storagepolicy.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, storagepolicy.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, storagepolicy.FieldName) + } + if m._type != nil { + fields = append(fields, storagepolicy.FieldType) + } + if m.server != nil { + fields = append(fields, storagepolicy.FieldServer) + } + if m.bucket_name != nil { + fields = append(fields, storagepolicy.FieldBucketName) + } + if m.is_private != nil { + fields = append(fields, storagepolicy.FieldIsPrivate) + } + if m.access_key != nil { + fields = append(fields, storagepolicy.FieldAccessKey) + } + if m.secret_key != nil { + fields = append(fields, storagepolicy.FieldSecretKey) + } + if m.max_size != nil { + fields = append(fields, storagepolicy.FieldMaxSize) + } + if m.dir_name_rule != nil { + fields = append(fields, storagepolicy.FieldDirNameRule) + } + if m.file_name_rule != nil { + fields = append(fields, storagepolicy.FieldFileNameRule) + } + if m.settings != nil { + fields = append(fields, storagepolicy.FieldSettings) + } + if m.node != nil { + fields = append(fields, storagepolicy.FieldNodeID) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *StoragePolicyMutation) Field(name string) (ent.Value, bool) { + switch name { + case storagepolicy.FieldCreatedAt: + return m.CreatedAt() + case storagepolicy.FieldUpdatedAt: + return m.UpdatedAt() + case storagepolicy.FieldDeletedAt: + return m.DeletedAt() + case storagepolicy.FieldName: + return m.Name() + case storagepolicy.FieldType: + return m.GetType() + case storagepolicy.FieldServer: + return m.Server() + case storagepolicy.FieldBucketName: + return m.BucketName() + case storagepolicy.FieldIsPrivate: + return m.IsPrivate() + case storagepolicy.FieldAccessKey: + return m.AccessKey() + case storagepolicy.FieldSecretKey: + return m.SecretKey() + case storagepolicy.FieldMaxSize: + return m.MaxSize() + case storagepolicy.FieldDirNameRule: + return m.DirNameRule() + case storagepolicy.FieldFileNameRule: + return m.FileNameRule() + case storagepolicy.FieldSettings: + return m.Settings() + case storagepolicy.FieldNodeID: + return m.NodeID() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *StoragePolicyMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case storagepolicy.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case storagepolicy.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case storagepolicy.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case storagepolicy.FieldName: + return m.OldName(ctx) + case storagepolicy.FieldType: + return m.OldType(ctx) + case storagepolicy.FieldServer: + return m.OldServer(ctx) + case storagepolicy.FieldBucketName: + return m.OldBucketName(ctx) + case storagepolicy.FieldIsPrivate: + return m.OldIsPrivate(ctx) + case storagepolicy.FieldAccessKey: + return m.OldAccessKey(ctx) + case storagepolicy.FieldSecretKey: + return m.OldSecretKey(ctx) + case storagepolicy.FieldMaxSize: + return m.OldMaxSize(ctx) + case storagepolicy.FieldDirNameRule: + return m.OldDirNameRule(ctx) + case storagepolicy.FieldFileNameRule: + return m.OldFileNameRule(ctx) + case storagepolicy.FieldSettings: + return m.OldSettings(ctx) + case storagepolicy.FieldNodeID: + return m.OldNodeID(ctx) + } + return nil, fmt.Errorf("unknown StoragePolicy field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *StoragePolicyMutation) SetField(name string, value ent.Value) error { + switch name { + case storagepolicy.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case storagepolicy.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case storagepolicy.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case storagepolicy.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case storagepolicy.FieldType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case storagepolicy.FieldServer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetServer(v) + return nil + case storagepolicy.FieldBucketName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBucketName(v) + return nil + case storagepolicy.FieldIsPrivate: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsPrivate(v) + return nil + case storagepolicy.FieldAccessKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccessKey(v) + return nil + case storagepolicy.FieldSecretKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSecretKey(v) + return nil + case storagepolicy.FieldMaxSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMaxSize(v) + return nil + case storagepolicy.FieldDirNameRule: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDirNameRule(v) + return nil + 
case storagepolicy.FieldFileNameRule: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFileNameRule(v) + return nil + case storagepolicy.FieldSettings: + v, ok := value.(*types.PolicySetting) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSettings(v) + return nil + case storagepolicy.FieldNodeID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetNodeID(v) + return nil + } + return fmt.Errorf("unknown StoragePolicy field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *StoragePolicyMutation) AddedFields() []string { + var fields []string + if m.addmax_size != nil { + fields = append(fields, storagepolicy.FieldMaxSize) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *StoragePolicyMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case storagepolicy.FieldMaxSize: + return m.AddedMaxSize() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *StoragePolicyMutation) AddField(name string, value ent.Value) error { + switch name { + case storagepolicy.FieldMaxSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMaxSize(v) + return nil + } + return fmt.Errorf("unknown StoragePolicy numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *StoragePolicyMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(storagepolicy.FieldDeletedAt) { + fields = append(fields, storagepolicy.FieldDeletedAt) + } + if m.FieldCleared(storagepolicy.FieldServer) { + fields = append(fields, storagepolicy.FieldServer) + } + if m.FieldCleared(storagepolicy.FieldBucketName) { + fields = append(fields, storagepolicy.FieldBucketName) + } + if m.FieldCleared(storagepolicy.FieldIsPrivate) { + fields = append(fields, storagepolicy.FieldIsPrivate) + } + if m.FieldCleared(storagepolicy.FieldAccessKey) { + fields = append(fields, storagepolicy.FieldAccessKey) + } + if m.FieldCleared(storagepolicy.FieldSecretKey) { + fields = append(fields, storagepolicy.FieldSecretKey) + } + if m.FieldCleared(storagepolicy.FieldMaxSize) { + fields = append(fields, storagepolicy.FieldMaxSize) + } + if m.FieldCleared(storagepolicy.FieldDirNameRule) { + fields = append(fields, storagepolicy.FieldDirNameRule) + } + if m.FieldCleared(storagepolicy.FieldFileNameRule) { + fields = append(fields, storagepolicy.FieldFileNameRule) + } + if m.FieldCleared(storagepolicy.FieldSettings) { + fields = append(fields, storagepolicy.FieldSettings) + } + if m.FieldCleared(storagepolicy.FieldNodeID) { + fields = append(fields, storagepolicy.FieldNodeID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *StoragePolicyMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. 
It returns an +// error if the field is not defined in the schema. +func (m *StoragePolicyMutation) ClearField(name string) error { + switch name { + case storagepolicy.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case storagepolicy.FieldServer: + m.ClearServer() + return nil + case storagepolicy.FieldBucketName: + m.ClearBucketName() + return nil + case storagepolicy.FieldIsPrivate: + m.ClearIsPrivate() + return nil + case storagepolicy.FieldAccessKey: + m.ClearAccessKey() + return nil + case storagepolicy.FieldSecretKey: + m.ClearSecretKey() + return nil + case storagepolicy.FieldMaxSize: + m.ClearMaxSize() + return nil + case storagepolicy.FieldDirNameRule: + m.ClearDirNameRule() + return nil + case storagepolicy.FieldFileNameRule: + m.ClearFileNameRule() + return nil + case storagepolicy.FieldSettings: + m.ClearSettings() + return nil + case storagepolicy.FieldNodeID: + m.ClearNodeID() + return nil + } + return fmt.Errorf("unknown StoragePolicy nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *StoragePolicyMutation) ResetField(name string) error { + switch name { + case storagepolicy.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case storagepolicy.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case storagepolicy.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case storagepolicy.FieldName: + m.ResetName() + return nil + case storagepolicy.FieldType: + m.ResetType() + return nil + case storagepolicy.FieldServer: + m.ResetServer() + return nil + case storagepolicy.FieldBucketName: + m.ResetBucketName() + return nil + case storagepolicy.FieldIsPrivate: + m.ResetIsPrivate() + return nil + case storagepolicy.FieldAccessKey: + m.ResetAccessKey() + return nil + case storagepolicy.FieldSecretKey: + m.ResetSecretKey() + return nil + case storagepolicy.FieldMaxSize: + m.ResetMaxSize() + return nil + case storagepolicy.FieldDirNameRule: + m.ResetDirNameRule() + return nil + case storagepolicy.FieldFileNameRule: + m.ResetFileNameRule() + return nil + case storagepolicy.FieldSettings: + m.ResetSettings() + return nil + case storagepolicy.FieldNodeID: + m.ResetNodeID() + return nil + } + return fmt.Errorf("unknown StoragePolicy field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *StoragePolicyMutation) AddedEdges() []string { + edges := make([]string, 0, 5) + if m.users != nil { + edges = append(edges, storagepolicy.EdgeUsers) + } + if m.groups != nil { + edges = append(edges, storagepolicy.EdgeGroups) + } + if m.files != nil { + edges = append(edges, storagepolicy.EdgeFiles) + } + if m.entities != nil { + edges = append(edges, storagepolicy.EdgeEntities) + } + if m.node != nil { + edges = append(edges, storagepolicy.EdgeNode) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *StoragePolicyMutation) AddedIDs(name string) []ent.Value { + switch name { + case storagepolicy.EdgeUsers: + ids := make([]ent.Value, 0, len(m.users)) + for id := range m.users { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeGroups: + ids := make([]ent.Value, 0, len(m.groups)) + for id := range m.groups { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeFiles: + ids := make([]ent.Value, 0, len(m.files)) + for id := range m.files { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeEntities: + ids := make([]ent.Value, 0, len(m.entities)) + for id := range m.entities { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeNode: + if id := m.node; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *StoragePolicyMutation) RemovedEdges() []string { + edges := make([]string, 0, 5) + if m.removedusers != nil { + edges = append(edges, storagepolicy.EdgeUsers) + } + if m.removedgroups != nil { + edges = append(edges, storagepolicy.EdgeGroups) + } + if m.removedfiles != nil { + edges = append(edges, storagepolicy.EdgeFiles) + } + if m.removedentities != nil { + edges = append(edges, storagepolicy.EdgeEntities) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *StoragePolicyMutation) RemovedIDs(name string) []ent.Value { + switch name { + case storagepolicy.EdgeUsers: + ids := make([]ent.Value, 0, len(m.removedusers)) + for id := range m.removedusers { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeGroups: + ids := make([]ent.Value, 0, len(m.removedgroups)) + for id := range m.removedgroups { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeFiles: + ids := make([]ent.Value, 0, len(m.removedfiles)) + for id := range m.removedfiles { + ids = append(ids, id) + } + return ids + case storagepolicy.EdgeEntities: + ids := make([]ent.Value, 0, len(m.removedentities)) + for id := range m.removedentities { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *StoragePolicyMutation) ClearedEdges() []string { + edges := make([]string, 0, 5) + if m.clearedusers { + edges = append(edges, storagepolicy.EdgeUsers) + } + if m.clearedgroups { + edges = append(edges, storagepolicy.EdgeGroups) + } + if m.clearedfiles { + edges = append(edges, storagepolicy.EdgeFiles) + } + if m.clearedentities { + edges = append(edges, storagepolicy.EdgeEntities) + } + if m.clearednode { + edges = append(edges, storagepolicy.EdgeNode) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *StoragePolicyMutation) EdgeCleared(name string) bool { + switch name { + case storagepolicy.EdgeUsers: + return m.clearedusers + case storagepolicy.EdgeGroups: + return m.clearedgroups + case storagepolicy.EdgeFiles: + return m.clearedfiles + case storagepolicy.EdgeEntities: + return m.clearedentities + case storagepolicy.EdgeNode: + return m.clearednode + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. 
+func (m *StoragePolicyMutation) ClearEdge(name string) error { + switch name { + case storagepolicy.EdgeNode: + m.ClearNode() + return nil + } + return fmt.Errorf("unknown StoragePolicy unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *StoragePolicyMutation) ResetEdge(name string) error { + switch name { + case storagepolicy.EdgeUsers: + m.ResetUsers() + return nil + case storagepolicy.EdgeGroups: + m.ResetGroups() + return nil + case storagepolicy.EdgeFiles: + m.ResetFiles() + return nil + case storagepolicy.EdgeEntities: + m.ResetEntities() + return nil + case storagepolicy.EdgeNode: + m.ResetNode() + return nil + } + return fmt.Errorf("unknown StoragePolicy edge %s", name) +} + +// TaskMutation represents an operation that mutates the Task nodes in the graph. +type TaskMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + _type *string + status *task.Status + public_state **types.TaskPublicState + private_state *string + correlation_id *uuid.UUID + clearedFields map[string]struct{} + user *int + cleareduser bool + done bool + oldValue func(context.Context) (*Task, error) + predicates []predicate.Task +} + +var _ ent.Mutation = (*TaskMutation)(nil) + +// taskOption allows management of the mutation configuration using functional options. +type taskOption func(*TaskMutation) + +// newTaskMutation creates new mutation for the Task entity. +func newTaskMutation(c config, op Op, opts ...taskOption) *TaskMutation { + m := &TaskMutation{ + config: c, + op: op, + typ: TypeTask, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withTaskID sets the ID field of the mutation. +func withTaskID(id int) taskOption { + return func(m *TaskMutation) { + var ( + err error + once sync.Once + value *Task + ) + m.oldValue = func(ctx context.Context) (*Task, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Task.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withTask sets the old Task of the mutation. +func withTask(node *Task) taskOption { + return func(m *TaskMutation) { + m.oldValue = func(context.Context) (*Task, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m TaskMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m TaskMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *TaskMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. 
+// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *TaskMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Task.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *TaskMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *TaskMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *TaskMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *TaskMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *TaskMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *TaskMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *TaskMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *TaskMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Task entity. 
+// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *TaskMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[task.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *TaskMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[task.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *TaskMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, task.FieldDeletedAt) +} + +// SetType sets the "type" field. +func (m *TaskMutation) SetType(s string) { + m._type = &s +} + +// GetType returns the value of the "type" field in the mutation. +func (m *TaskMutation) GetType() (r string, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// ResetType resets all changes to the "type" field. +func (m *TaskMutation) ResetType() { + m._type = nil +} + +// SetStatus sets the "status" field. +func (m *TaskMutation) SetStatus(t task.Status) { + m.status = &t +} + +// Status returns the value of the "status" field in the mutation. +func (m *TaskMutation) Status() (r task.Status, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldStatus(ctx context.Context) (v task.Status, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. 
+func (m *TaskMutation) ResetStatus() { + m.status = nil +} + +// SetPublicState sets the "public_state" field. +func (m *TaskMutation) SetPublicState(tps *types.TaskPublicState) { + m.public_state = &tps +} + +// PublicState returns the value of the "public_state" field in the mutation. +func (m *TaskMutation) PublicState() (r *types.TaskPublicState, exists bool) { + v := m.public_state + if v == nil { + return + } + return *v, true +} + +// OldPublicState returns the old "public_state" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldPublicState(ctx context.Context) (v *types.TaskPublicState, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPublicState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPublicState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPublicState: %w", err) + } + return oldValue.PublicState, nil +} + +// ResetPublicState resets all changes to the "public_state" field. +func (m *TaskMutation) ResetPublicState() { + m.public_state = nil +} + +// SetPrivateState sets the "private_state" field. +func (m *TaskMutation) SetPrivateState(s string) { + m.private_state = &s +} + +// PrivateState returns the value of the "private_state" field in the mutation. +func (m *TaskMutation) PrivateState() (r string, exists bool) { + v := m.private_state + if v == nil { + return + } + return *v, true +} + +// OldPrivateState returns the old "private_state" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldPrivateState(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrivateState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrivateState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrivateState: %w", err) + } + return oldValue.PrivateState, nil +} + +// ClearPrivateState clears the value of the "private_state" field. +func (m *TaskMutation) ClearPrivateState() { + m.private_state = nil + m.clearedFields[task.FieldPrivateState] = struct{}{} +} + +// PrivateStateCleared returns if the "private_state" field was cleared in this mutation. +func (m *TaskMutation) PrivateStateCleared() bool { + _, ok := m.clearedFields[task.FieldPrivateState] + return ok +} + +// ResetPrivateState resets all changes to the "private_state" field. +func (m *TaskMutation) ResetPrivateState() { + m.private_state = nil + delete(m.clearedFields, task.FieldPrivateState) +} + +// SetCorrelationID sets the "correlation_id" field. +func (m *TaskMutation) SetCorrelationID(u uuid.UUID) { + m.correlation_id = &u +} + +// CorrelationID returns the value of the "correlation_id" field in the mutation. 
+func (m *TaskMutation) CorrelationID() (r uuid.UUID, exists bool) { + v := m.correlation_id + if v == nil { + return + } + return *v, true +} + +// OldCorrelationID returns the old "correlation_id" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldCorrelationID(ctx context.Context) (v uuid.UUID, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCorrelationID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCorrelationID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCorrelationID: %w", err) + } + return oldValue.CorrelationID, nil +} + +// ClearCorrelationID clears the value of the "correlation_id" field. +func (m *TaskMutation) ClearCorrelationID() { + m.correlation_id = nil + m.clearedFields[task.FieldCorrelationID] = struct{}{} +} + +// CorrelationIDCleared returns if the "correlation_id" field was cleared in this mutation. +func (m *TaskMutation) CorrelationIDCleared() bool { + _, ok := m.clearedFields[task.FieldCorrelationID] + return ok +} + +// ResetCorrelationID resets all changes to the "correlation_id" field. +func (m *TaskMutation) ResetCorrelationID() { + m.correlation_id = nil + delete(m.clearedFields, task.FieldCorrelationID) +} + +// SetUserTasks sets the "user_tasks" field. +func (m *TaskMutation) SetUserTasks(i int) { + m.user = &i +} + +// UserTasks returns the value of the "user_tasks" field in the mutation. +func (m *TaskMutation) UserTasks() (r int, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserTasks returns the old "user_tasks" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldUserTasks(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserTasks is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserTasks requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserTasks: %w", err) + } + return oldValue.UserTasks, nil +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (m *TaskMutation) ClearUserTasks() { + m.user = nil + m.clearedFields[task.FieldUserTasks] = struct{}{} +} + +// UserTasksCleared returns if the "user_tasks" field was cleared in this mutation. +func (m *TaskMutation) UserTasksCleared() bool { + _, ok := m.clearedFields[task.FieldUserTasks] + return ok +} + +// ResetUserTasks resets all changes to the "user_tasks" field. +func (m *TaskMutation) ResetUserTasks() { + m.user = nil + delete(m.clearedFields, task.FieldUserTasks) +} + +// SetUserID sets the "user" edge to the User entity by id. +func (m *TaskMutation) SetUserID(id int) { + m.user = &id +} + +// ClearUser clears the "user" edge to the User entity. +func (m *TaskMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[task.FieldUserTasks] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. 
+func (m *TaskMutation) UserCleared() bool { + return m.UserTasksCleared() || m.cleareduser +} + +// UserID returns the "user" edge ID in the mutation. +func (m *TaskMutation) UserID() (id int, exists bool) { + if m.user != nil { + return *m.user, true + } + return +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *TaskMutation) UserIDs() (ids []int) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *TaskMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the TaskMutation builder. +func (m *TaskMutation) Where(ps ...predicate.Task) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the TaskMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *TaskMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Task, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *TaskMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *TaskMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Task). +func (m *TaskMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *TaskMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, task.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, task.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, task.FieldDeletedAt) + } + if m._type != nil { + fields = append(fields, task.FieldType) + } + if m.status != nil { + fields = append(fields, task.FieldStatus) + } + if m.public_state != nil { + fields = append(fields, task.FieldPublicState) + } + if m.private_state != nil { + fields = append(fields, task.FieldPrivateState) + } + if m.correlation_id != nil { + fields = append(fields, task.FieldCorrelationID) + } + if m.user != nil { + fields = append(fields, task.FieldUserTasks) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *TaskMutation) Field(name string) (ent.Value, bool) { + switch name { + case task.FieldCreatedAt: + return m.CreatedAt() + case task.FieldUpdatedAt: + return m.UpdatedAt() + case task.FieldDeletedAt: + return m.DeletedAt() + case task.FieldType: + return m.GetType() + case task.FieldStatus: + return m.Status() + case task.FieldPublicState: + return m.PublicState() + case task.FieldPrivateState: + return m.PrivateState() + case task.FieldCorrelationID: + return m.CorrelationID() + case task.FieldUserTasks: + return m.UserTasks() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *TaskMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case task.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case task.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case task.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case task.FieldType: + return m.OldType(ctx) + case task.FieldStatus: + return m.OldStatus(ctx) + case task.FieldPublicState: + return m.OldPublicState(ctx) + case task.FieldPrivateState: + return m.OldPrivateState(ctx) + case task.FieldCorrelationID: + return m.OldCorrelationID(ctx) + case task.FieldUserTasks: + return m.OldUserTasks(ctx) + } + return nil, fmt.Errorf("unknown Task field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *TaskMutation) SetField(name string, value ent.Value) error { + switch name { + case task.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case task.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case task.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case task.FieldType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case task.FieldStatus: + v, ok := value.(task.Status) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case task.FieldPublicState: + v, ok := value.(*types.TaskPublicState) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPublicState(v) + return nil + case task.FieldPrivateState: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrivateState(v) + return nil + case task.FieldCorrelationID: + v, ok := value.(uuid.UUID) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCorrelationID(v) + return nil + case task.FieldUserTasks: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserTasks(v) + return nil + } + return fmt.Errorf("unknown Task field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *TaskMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *TaskMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *TaskMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown Task numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *TaskMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(task.FieldDeletedAt) { + fields = append(fields, task.FieldDeletedAt) + } + if m.FieldCleared(task.FieldPrivateState) { + fields = append(fields, task.FieldPrivateState) + } + if m.FieldCleared(task.FieldCorrelationID) { + fields = append(fields, task.FieldCorrelationID) + } + if m.FieldCleared(task.FieldUserTasks) { + fields = append(fields, task.FieldUserTasks) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *TaskMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *TaskMutation) ClearField(name string) error { + switch name { + case task.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case task.FieldPrivateState: + m.ClearPrivateState() + return nil + case task.FieldCorrelationID: + m.ClearCorrelationID() + return nil + case task.FieldUserTasks: + m.ClearUserTasks() + return nil + } + return fmt.Errorf("unknown Task nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *TaskMutation) ResetField(name string) error { + switch name { + case task.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case task.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case task.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case task.FieldType: + m.ResetType() + return nil + case task.FieldStatus: + m.ResetStatus() + return nil + case task.FieldPublicState: + m.ResetPublicState() + return nil + case task.FieldPrivateState: + m.ResetPrivateState() + return nil + case task.FieldCorrelationID: + m.ResetCorrelationID() + return nil + case task.FieldUserTasks: + m.ResetUserTasks() + return nil + } + return fmt.Errorf("unknown Task field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *TaskMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, task.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *TaskMutation) AddedIDs(name string) []ent.Value { + switch name { + case task.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *TaskMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *TaskMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *TaskMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, task.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *TaskMutation) EdgeCleared(name string) bool { + switch name { + case task.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. 
It returns an error +// if that edge is not defined in the schema. +func (m *TaskMutation) ClearEdge(name string) error { + switch name { + case task.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown Task unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *TaskMutation) ResetEdge(name string) error { + switch name { + case task.EdgeUser: + m.ResetUser() + return nil + } + return fmt.Errorf("unknown Task edge %s", name) +} + +// UserMutation represents an operation that mutates the User nodes in the graph. +type UserMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + email *string + nick *string + password *string + status *user.Status + storage *int64 + addstorage *int64 + two_factor_secret *string + avatar *string + settings **types.UserSetting + clearedFields map[string]struct{} + group *int + clearedgroup bool + files map[int]struct{} + removedfiles map[int]struct{} + clearedfiles bool + dav_accounts map[int]struct{} + removeddav_accounts map[int]struct{} + cleareddav_accounts bool + shares map[int]struct{} + removedshares map[int]struct{} + clearedshares bool + passkey map[int]struct{} + removedpasskey map[int]struct{} + clearedpasskey bool + tasks map[int]struct{} + removedtasks map[int]struct{} + clearedtasks bool + entities map[int]struct{} + removedentities map[int]struct{} + clearedentities bool + done bool + oldValue func(context.Context) (*User, error) + predicates []predicate.User +} + +var _ ent.Mutation = (*UserMutation)(nil) + +// userOption allows management of the mutation configuration using functional options. +type userOption func(*UserMutation) + +// newUserMutation creates new mutation for the User entity. +func newUserMutation(c config, op Op, opts ...userOption) *UserMutation { + m := &UserMutation{ + config: c, + op: op, + typ: TypeUser, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUserID sets the ID field of the mutation. +func withUserID(id int) userOption { + return func(m *UserMutation) { + var ( + err error + once sync.Once + value *User + ) + m.oldValue = func(ctx context.Context) (*User, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().User.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUser sets the old User of the mutation. +func withUser(node *User) userOption { + return func(m *UserMutation) { + m.oldValue = func(context.Context) (*User, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UserMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UserMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. 
Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UserMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *UserMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().User.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *UserMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UserMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UserMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *UserMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *UserMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *UserMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. 
+func (m *UserMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[user.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *UserMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[user.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, user.FieldDeletedAt) +} + +// SetEmail sets the "email" field. +func (m *UserMutation) SetEmail(s string) { + m.email = &s +} + +// Email returns the value of the "email" field in the mutation. +func (m *UserMutation) Email() (r string, exists bool) { + v := m.email + if v == nil { + return + } + return *v, true +} + +// OldEmail returns the old "email" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldEmail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEmail: %w", err) + } + return oldValue.Email, nil +} + +// ResetEmail resets all changes to the "email" field. +func (m *UserMutation) ResetEmail() { + m.email = nil +} + +// SetNick sets the "nick" field. +func (m *UserMutation) SetNick(s string) { + m.nick = &s +} + +// Nick returns the value of the "nick" field in the mutation. +func (m *UserMutation) Nick() (r string, exists bool) { + v := m.nick + if v == nil { + return + } + return *v, true +} + +// OldNick returns the old "nick" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UserMutation) OldNick(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldNick is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldNick requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldNick: %w", err) + } + return oldValue.Nick, nil +} + +// ResetNick resets all changes to the "nick" field. +func (m *UserMutation) ResetNick() { + m.nick = nil +} + +// SetPassword sets the "password" field. +func (m *UserMutation) SetPassword(s string) { + m.password = &s +} + +// Password returns the value of the "password" field in the mutation. +func (m *UserMutation) Password() (r string, exists bool) { + v := m.password + if v == nil { + return + } + return *v, true +} + +// OldPassword returns the old "password" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldPassword(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassword is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassword requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassword: %w", err) + } + return oldValue.Password, nil +} + +// ClearPassword clears the value of the "password" field. +func (m *UserMutation) ClearPassword() { + m.password = nil + m.clearedFields[user.FieldPassword] = struct{}{} +} + +// PasswordCleared returns if the "password" field was cleared in this mutation. +func (m *UserMutation) PasswordCleared() bool { + _, ok := m.clearedFields[user.FieldPassword] + return ok +} + +// ResetPassword resets all changes to the "password" field. +func (m *UserMutation) ResetPassword() { + m.password = nil + delete(m.clearedFields, user.FieldPassword) +} + +// SetStatus sets the "status" field. +func (m *UserMutation) SetStatus(u user.Status) { + m.status = &u +} + +// Status returns the value of the "status" field in the mutation. +func (m *UserMutation) Status() (r user.Status, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldStatus(ctx context.Context) (v user.Status, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *UserMutation) ResetStatus() { + m.status = nil +} + +// SetStorage sets the "storage" field. +func (m *UserMutation) SetStorage(i int64) { + m.storage = &i + m.addstorage = nil +} + +// Storage returns the value of the "storage" field in the mutation. 
+func (m *UserMutation) Storage() (r int64, exists bool) { + v := m.storage + if v == nil { + return + } + return *v, true +} + +// OldStorage returns the old "storage" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldStorage(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStorage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStorage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStorage: %w", err) + } + return oldValue.Storage, nil +} + +// AddStorage adds i to the "storage" field. +func (m *UserMutation) AddStorage(i int64) { + if m.addstorage != nil { + *m.addstorage += i + } else { + m.addstorage = &i + } +} + +// AddedStorage returns the value that was added to the "storage" field in this mutation. +func (m *UserMutation) AddedStorage() (r int64, exists bool) { + v := m.addstorage + if v == nil { + return + } + return *v, true +} + +// ResetStorage resets all changes to the "storage" field. +func (m *UserMutation) ResetStorage() { + m.storage = nil + m.addstorage = nil +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (m *UserMutation) SetTwoFactorSecret(s string) { + m.two_factor_secret = &s +} + +// TwoFactorSecret returns the value of the "two_factor_secret" field in the mutation. +func (m *UserMutation) TwoFactorSecret() (r string, exists bool) { + v := m.two_factor_secret + if v == nil { + return + } + return *v, true +} + +// OldTwoFactorSecret returns the old "two_factor_secret" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTwoFactorSecret(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTwoFactorSecret is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTwoFactorSecret requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTwoFactorSecret: %w", err) + } + return oldValue.TwoFactorSecret, nil +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (m *UserMutation) ClearTwoFactorSecret() { + m.two_factor_secret = nil + m.clearedFields[user.FieldTwoFactorSecret] = struct{}{} +} + +// TwoFactorSecretCleared returns if the "two_factor_secret" field was cleared in this mutation. +func (m *UserMutation) TwoFactorSecretCleared() bool { + _, ok := m.clearedFields[user.FieldTwoFactorSecret] + return ok +} + +// ResetTwoFactorSecret resets all changes to the "two_factor_secret" field. +func (m *UserMutation) ResetTwoFactorSecret() { + m.two_factor_secret = nil + delete(m.clearedFields, user.FieldTwoFactorSecret) +} + +// SetAvatar sets the "avatar" field. +func (m *UserMutation) SetAvatar(s string) { + m.avatar = &s +} + +// Avatar returns the value of the "avatar" field in the mutation. 
+func (m *UserMutation) Avatar() (r string, exists bool) { + v := m.avatar + if v == nil { + return + } + return *v, true +} + +// OldAvatar returns the old "avatar" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldAvatar(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAvatar is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAvatar requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAvatar: %w", err) + } + return oldValue.Avatar, nil +} + +// ClearAvatar clears the value of the "avatar" field. +func (m *UserMutation) ClearAvatar() { + m.avatar = nil + m.clearedFields[user.FieldAvatar] = struct{}{} +} + +// AvatarCleared returns if the "avatar" field was cleared in this mutation. +func (m *UserMutation) AvatarCleared() bool { + _, ok := m.clearedFields[user.FieldAvatar] + return ok +} + +// ResetAvatar resets all changes to the "avatar" field. +func (m *UserMutation) ResetAvatar() { + m.avatar = nil + delete(m.clearedFields, user.FieldAvatar) +} + +// SetSettings sets the "settings" field. +func (m *UserMutation) SetSettings(ts *types.UserSetting) { + m.settings = &ts +} + +// Settings returns the value of the "settings" field in the mutation. +func (m *UserMutation) Settings() (r *types.UserSetting, exists bool) { + v := m.settings + if v == nil { + return + } + return *v, true +} + +// OldSettings returns the old "settings" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSettings(ctx context.Context) (v *types.UserSetting, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSettings is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSettings requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSettings: %w", err) + } + return oldValue.Settings, nil +} + +// ClearSettings clears the value of the "settings" field. +func (m *UserMutation) ClearSettings() { + m.settings = nil + m.clearedFields[user.FieldSettings] = struct{}{} +} + +// SettingsCleared returns if the "settings" field was cleared in this mutation. +func (m *UserMutation) SettingsCleared() bool { + _, ok := m.clearedFields[user.FieldSettings] + return ok +} + +// ResetSettings resets all changes to the "settings" field. +func (m *UserMutation) ResetSettings() { + m.settings = nil + delete(m.clearedFields, user.FieldSettings) +} + +// SetGroupUsers sets the "group_users" field. +func (m *UserMutation) SetGroupUsers(i int) { + m.group = &i +} + +// GroupUsers returns the value of the "group_users" field in the mutation. +func (m *UserMutation) GroupUsers() (r int, exists bool) { + v := m.group + if v == nil { + return + } + return *v, true +} + +// OldGroupUsers returns the old "group_users" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldGroupUsers(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldGroupUsers is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldGroupUsers requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldGroupUsers: %w", err) + } + return oldValue.GroupUsers, nil +} + +// ResetGroupUsers resets all changes to the "group_users" field. +func (m *UserMutation) ResetGroupUsers() { + m.group = nil +} + +// SetGroupID sets the "group" edge to the Group entity by id. +func (m *UserMutation) SetGroupID(id int) { + m.group = &id +} + +// ClearGroup clears the "group" edge to the Group entity. +func (m *UserMutation) ClearGroup() { + m.clearedgroup = true + m.clearedFields[user.FieldGroupUsers] = struct{}{} +} + +// GroupCleared reports if the "group" edge to the Group entity was cleared. +func (m *UserMutation) GroupCleared() bool { + return m.clearedgroup +} + +// GroupID returns the "group" edge ID in the mutation. +func (m *UserMutation) GroupID() (id int, exists bool) { + if m.group != nil { + return *m.group, true + } + return +} + +// GroupIDs returns the "group" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// GroupID instead. It exists only for internal usage by the builders. +func (m *UserMutation) GroupIDs() (ids []int) { + if id := m.group; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetGroup resets all changes to the "group" edge. +func (m *UserMutation) ResetGroup() { + m.group = nil + m.clearedgroup = false +} + +// AddFileIDs adds the "files" edge to the File entity by ids. +func (m *UserMutation) AddFileIDs(ids ...int) { + if m.files == nil { + m.files = make(map[int]struct{}) + } + for i := range ids { + m.files[ids[i]] = struct{}{} + } +} + +// ClearFiles clears the "files" edge to the File entity. +func (m *UserMutation) ClearFiles() { + m.clearedfiles = true +} + +// FilesCleared reports if the "files" edge to the File entity was cleared. +func (m *UserMutation) FilesCleared() bool { + return m.clearedfiles +} + +// RemoveFileIDs removes the "files" edge to the File entity by IDs. +func (m *UserMutation) RemoveFileIDs(ids ...int) { + if m.removedfiles == nil { + m.removedfiles = make(map[int]struct{}) + } + for i := range ids { + delete(m.files, ids[i]) + m.removedfiles[ids[i]] = struct{}{} + } +} + +// RemovedFiles returns the removed IDs of the "files" edge to the File entity. +func (m *UserMutation) RemovedFilesIDs() (ids []int) { + for id := range m.removedfiles { + ids = append(ids, id) + } + return +} + +// FilesIDs returns the "files" edge IDs in the mutation. +func (m *UserMutation) FilesIDs() (ids []int) { + for id := range m.files { + ids = append(ids, id) + } + return +} + +// ResetFiles resets all changes to the "files" edge. +func (m *UserMutation) ResetFiles() { + m.files = nil + m.clearedfiles = false + m.removedfiles = nil +} + +// AddDavAccountIDs adds the "dav_accounts" edge to the DavAccount entity by ids. 
+func (m *UserMutation) AddDavAccountIDs(ids ...int) { + if m.dav_accounts == nil { + m.dav_accounts = make(map[int]struct{}) + } + for i := range ids { + m.dav_accounts[ids[i]] = struct{}{} + } +} + +// ClearDavAccounts clears the "dav_accounts" edge to the DavAccount entity. +func (m *UserMutation) ClearDavAccounts() { + m.cleareddav_accounts = true +} + +// DavAccountsCleared reports if the "dav_accounts" edge to the DavAccount entity was cleared. +func (m *UserMutation) DavAccountsCleared() bool { + return m.cleareddav_accounts +} + +// RemoveDavAccountIDs removes the "dav_accounts" edge to the DavAccount entity by IDs. +func (m *UserMutation) RemoveDavAccountIDs(ids ...int) { + if m.removeddav_accounts == nil { + m.removeddav_accounts = make(map[int]struct{}) + } + for i := range ids { + delete(m.dav_accounts, ids[i]) + m.removeddav_accounts[ids[i]] = struct{}{} + } +} + +// RemovedDavAccounts returns the removed IDs of the "dav_accounts" edge to the DavAccount entity. +func (m *UserMutation) RemovedDavAccountsIDs() (ids []int) { + for id := range m.removeddav_accounts { + ids = append(ids, id) + } + return +} + +// DavAccountsIDs returns the "dav_accounts" edge IDs in the mutation. +func (m *UserMutation) DavAccountsIDs() (ids []int) { + for id := range m.dav_accounts { + ids = append(ids, id) + } + return +} + +// ResetDavAccounts resets all changes to the "dav_accounts" edge. +func (m *UserMutation) ResetDavAccounts() { + m.dav_accounts = nil + m.cleareddav_accounts = false + m.removeddav_accounts = nil +} + +// AddShareIDs adds the "shares" edge to the Share entity by ids. +func (m *UserMutation) AddShareIDs(ids ...int) { + if m.shares == nil { + m.shares = make(map[int]struct{}) + } + for i := range ids { + m.shares[ids[i]] = struct{}{} + } +} + +// ClearShares clears the "shares" edge to the Share entity. +func (m *UserMutation) ClearShares() { + m.clearedshares = true +} + +// SharesCleared reports if the "shares" edge to the Share entity was cleared. +func (m *UserMutation) SharesCleared() bool { + return m.clearedshares +} + +// RemoveShareIDs removes the "shares" edge to the Share entity by IDs. +func (m *UserMutation) RemoveShareIDs(ids ...int) { + if m.removedshares == nil { + m.removedshares = make(map[int]struct{}) + } + for i := range ids { + delete(m.shares, ids[i]) + m.removedshares[ids[i]] = struct{}{} + } +} + +// RemovedShares returns the removed IDs of the "shares" edge to the Share entity. +func (m *UserMutation) RemovedSharesIDs() (ids []int) { + for id := range m.removedshares { + ids = append(ids, id) + } + return +} + +// SharesIDs returns the "shares" edge IDs in the mutation. +func (m *UserMutation) SharesIDs() (ids []int) { + for id := range m.shares { + ids = append(ids, id) + } + return +} + +// ResetShares resets all changes to the "shares" edge. +func (m *UserMutation) ResetShares() { + m.shares = nil + m.clearedshares = false + m.removedshares = nil +} + +// AddPasskeyIDs adds the "passkey" edge to the Passkey entity by ids. +func (m *UserMutation) AddPasskeyIDs(ids ...int) { + if m.passkey == nil { + m.passkey = make(map[int]struct{}) + } + for i := range ids { + m.passkey[ids[i]] = struct{}{} + } +} + +// ClearPasskey clears the "passkey" edge to the Passkey entity. +func (m *UserMutation) ClearPasskey() { + m.clearedpasskey = true +} + +// PasskeyCleared reports if the "passkey" edge to the Passkey entity was cleared. 
+func (m *UserMutation) PasskeyCleared() bool { + return m.clearedpasskey +} + +// RemovePasskeyIDs removes the "passkey" edge to the Passkey entity by IDs. +func (m *UserMutation) RemovePasskeyIDs(ids ...int) { + if m.removedpasskey == nil { + m.removedpasskey = make(map[int]struct{}) + } + for i := range ids { + delete(m.passkey, ids[i]) + m.removedpasskey[ids[i]] = struct{}{} + } +} + +// RemovedPasskey returns the removed IDs of the "passkey" edge to the Passkey entity. +func (m *UserMutation) RemovedPasskeyIDs() (ids []int) { + for id := range m.removedpasskey { + ids = append(ids, id) + } + return +} + +// PasskeyIDs returns the "passkey" edge IDs in the mutation. +func (m *UserMutation) PasskeyIDs() (ids []int) { + for id := range m.passkey { + ids = append(ids, id) + } + return +} + +// ResetPasskey resets all changes to the "passkey" edge. +func (m *UserMutation) ResetPasskey() { + m.passkey = nil + m.clearedpasskey = false + m.removedpasskey = nil +} + +// AddTaskIDs adds the "tasks" edge to the Task entity by ids. +func (m *UserMutation) AddTaskIDs(ids ...int) { + if m.tasks == nil { + m.tasks = make(map[int]struct{}) + } + for i := range ids { + m.tasks[ids[i]] = struct{}{} + } +} + +// ClearTasks clears the "tasks" edge to the Task entity. +func (m *UserMutation) ClearTasks() { + m.clearedtasks = true +} + +// TasksCleared reports if the "tasks" edge to the Task entity was cleared. +func (m *UserMutation) TasksCleared() bool { + return m.clearedtasks +} + +// RemoveTaskIDs removes the "tasks" edge to the Task entity by IDs. +func (m *UserMutation) RemoveTaskIDs(ids ...int) { + if m.removedtasks == nil { + m.removedtasks = make(map[int]struct{}) + } + for i := range ids { + delete(m.tasks, ids[i]) + m.removedtasks[ids[i]] = struct{}{} + } +} + +// RemovedTasks returns the removed IDs of the "tasks" edge to the Task entity. +func (m *UserMutation) RemovedTasksIDs() (ids []int) { + for id := range m.removedtasks { + ids = append(ids, id) + } + return +} + +// TasksIDs returns the "tasks" edge IDs in the mutation. +func (m *UserMutation) TasksIDs() (ids []int) { + for id := range m.tasks { + ids = append(ids, id) + } + return +} + +// ResetTasks resets all changes to the "tasks" edge. +func (m *UserMutation) ResetTasks() { + m.tasks = nil + m.clearedtasks = false + m.removedtasks = nil +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by ids. +func (m *UserMutation) AddEntityIDs(ids ...int) { + if m.entities == nil { + m.entities = make(map[int]struct{}) + } + for i := range ids { + m.entities[ids[i]] = struct{}{} + } +} + +// ClearEntities clears the "entities" edge to the Entity entity. +func (m *UserMutation) ClearEntities() { + m.clearedentities = true +} + +// EntitiesCleared reports if the "entities" edge to the Entity entity was cleared. +func (m *UserMutation) EntitiesCleared() bool { + return m.clearedentities +} + +// RemoveEntityIDs removes the "entities" edge to the Entity entity by IDs. +func (m *UserMutation) RemoveEntityIDs(ids ...int) { + if m.removedentities == nil { + m.removedentities = make(map[int]struct{}) + } + for i := range ids { + delete(m.entities, ids[i]) + m.removedentities[ids[i]] = struct{}{} + } +} + +// RemovedEntities returns the removed IDs of the "entities" edge to the Entity entity. +func (m *UserMutation) RemovedEntitiesIDs() (ids []int) { + for id := range m.removedentities { + ids = append(ids, id) + } + return +} + +// EntitiesIDs returns the "entities" edge IDs in the mutation. 
+func (m *UserMutation) EntitiesIDs() (ids []int) { + for id := range m.entities { + ids = append(ids, id) + } + return +} + +// ResetEntities resets all changes to the "entities" edge. +func (m *UserMutation) ResetEntities() { + m.entities = nil + m.clearedentities = false + m.removedentities = nil +} + +// Where appends a list predicates to the UserMutation builder. +func (m *UserMutation) Where(ps ...predicate.User) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UserMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UserMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.User, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UserMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UserMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (User). +func (m *UserMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *UserMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.created_at != nil { + fields = append(fields, user.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, user.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, user.FieldDeletedAt) + } + if m.email != nil { + fields = append(fields, user.FieldEmail) + } + if m.nick != nil { + fields = append(fields, user.FieldNick) + } + if m.password != nil { + fields = append(fields, user.FieldPassword) + } + if m.status != nil { + fields = append(fields, user.FieldStatus) + } + if m.storage != nil { + fields = append(fields, user.FieldStorage) + } + if m.two_factor_secret != nil { + fields = append(fields, user.FieldTwoFactorSecret) + } + if m.avatar != nil { + fields = append(fields, user.FieldAvatar) + } + if m.settings != nil { + fields = append(fields, user.FieldSettings) + } + if m.group != nil { + fields = append(fields, user.FieldGroupUsers) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UserMutation) Field(name string) (ent.Value, bool) { + switch name { + case user.FieldCreatedAt: + return m.CreatedAt() + case user.FieldUpdatedAt: + return m.UpdatedAt() + case user.FieldDeletedAt: + return m.DeletedAt() + case user.FieldEmail: + return m.Email() + case user.FieldNick: + return m.Nick() + case user.FieldPassword: + return m.Password() + case user.FieldStatus: + return m.Status() + case user.FieldStorage: + return m.Storage() + case user.FieldTwoFactorSecret: + return m.TwoFactorSecret() + case user.FieldAvatar: + return m.Avatar() + case user.FieldSettings: + return m.Settings() + case user.FieldGroupUsers: + return m.GroupUsers() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case user.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case user.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case user.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case user.FieldEmail: + return m.OldEmail(ctx) + case user.FieldNick: + return m.OldNick(ctx) + case user.FieldPassword: + return m.OldPassword(ctx) + case user.FieldStatus: + return m.OldStatus(ctx) + case user.FieldStorage: + return m.OldStorage(ctx) + case user.FieldTwoFactorSecret: + return m.OldTwoFactorSecret(ctx) + case user.FieldAvatar: + return m.OldAvatar(ctx) + case user.FieldSettings: + return m.OldSettings(ctx) + case user.FieldGroupUsers: + return m.OldGroupUsers(ctx) + } + return nil, fmt.Errorf("unknown User field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserMutation) SetField(name string, value ent.Value) error { + switch name { + case user.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case user.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case user.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case user.FieldEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEmail(v) + return nil + case user.FieldNick: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetNick(v) + return nil + case user.FieldPassword: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassword(v) + return nil + case user.FieldStatus: + v, ok := value.(user.Status) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case user.FieldStorage: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStorage(v) + return nil + case user.FieldTwoFactorSecret: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTwoFactorSecret(v) + return nil + case user.FieldAvatar: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAvatar(v) + return nil + case user.FieldSettings: + v, ok := value.(*types.UserSetting) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSettings(v) + return nil + case user.FieldGroupUsers: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetGroupUsers(v) + return nil + } + return fmt.Errorf("unknown User field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *UserMutation) AddedFields() []string { + var fields []string + if m.addstorage != nil { + fields = append(fields, user.FieldStorage) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UserMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case user.FieldStorage: + return m.AddedStorage() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserMutation) AddField(name string, value ent.Value) error { + switch name { + case user.FieldStorage: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddStorage(v) + return nil + } + return fmt.Errorf("unknown User numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UserMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(user.FieldDeletedAt) { + fields = append(fields, user.FieldDeletedAt) + } + if m.FieldCleared(user.FieldPassword) { + fields = append(fields, user.FieldPassword) + } + if m.FieldCleared(user.FieldTwoFactorSecret) { + fields = append(fields, user.FieldTwoFactorSecret) + } + if m.FieldCleared(user.FieldAvatar) { + fields = append(fields, user.FieldAvatar) + } + if m.FieldCleared(user.FieldSettings) { + fields = append(fields, user.FieldSettings) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UserMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UserMutation) ClearField(name string) error { + switch name { + case user.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case user.FieldPassword: + m.ClearPassword() + return nil + case user.FieldTwoFactorSecret: + m.ClearTwoFactorSecret() + return nil + case user.FieldAvatar: + m.ClearAvatar() + return nil + case user.FieldSettings: + m.ClearSettings() + return nil + } + return fmt.Errorf("unknown User nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *UserMutation) ResetField(name string) error { + switch name { + case user.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case user.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case user.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case user.FieldEmail: + m.ResetEmail() + return nil + case user.FieldNick: + m.ResetNick() + return nil + case user.FieldPassword: + m.ResetPassword() + return nil + case user.FieldStatus: + m.ResetStatus() + return nil + case user.FieldStorage: + m.ResetStorage() + return nil + case user.FieldTwoFactorSecret: + m.ResetTwoFactorSecret() + return nil + case user.FieldAvatar: + m.ResetAvatar() + return nil + case user.FieldSettings: + m.ResetSettings() + return nil + case user.FieldGroupUsers: + m.ResetGroupUsers() + return nil + } + return fmt.Errorf("unknown User field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UserMutation) AddedEdges() []string { + edges := make([]string, 0, 7) + if m.group != nil { + edges = append(edges, user.EdgeGroup) + } + if m.files != nil { + edges = append(edges, user.EdgeFiles) + } + if m.dav_accounts != nil { + edges = append(edges, user.EdgeDavAccounts) + } + if m.shares != nil { + edges = append(edges, user.EdgeShares) + } + if m.passkey != nil { + edges = append(edges, user.EdgePasskey) + } + if m.tasks != nil { + edges = append(edges, user.EdgeTasks) + } + if m.entities != nil { + edges = append(edges, user.EdgeEntities) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UserMutation) AddedIDs(name string) []ent.Value { + switch name { + case user.EdgeGroup: + if id := m.group; id != nil { + return []ent.Value{*id} + } + case user.EdgeFiles: + ids := make([]ent.Value, 0, len(m.files)) + for id := range m.files { + ids = append(ids, id) + } + return ids + case user.EdgeDavAccounts: + ids := make([]ent.Value, 0, len(m.dav_accounts)) + for id := range m.dav_accounts { + ids = append(ids, id) + } + return ids + case user.EdgeShares: + ids := make([]ent.Value, 0, len(m.shares)) + for id := range m.shares { + ids = append(ids, id) + } + return ids + case user.EdgePasskey: + ids := make([]ent.Value, 0, len(m.passkey)) + for id := range m.passkey { + ids = append(ids, id) + } + return ids + case user.EdgeTasks: + ids := make([]ent.Value, 0, len(m.tasks)) + for id := range m.tasks { + ids = append(ids, id) + } + return ids + case user.EdgeEntities: + ids := make([]ent.Value, 0, len(m.entities)) + for id := range m.entities { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UserMutation) RemovedEdges() []string { + edges := make([]string, 0, 7) + if m.removedfiles != nil { + edges = append(edges, user.EdgeFiles) + } + if m.removeddav_accounts != nil { + edges = append(edges, user.EdgeDavAccounts) + } + if m.removedshares != nil { + edges = append(edges, user.EdgeShares) + } + if m.removedpasskey != nil { + edges = append(edges, user.EdgePasskey) + } + if m.removedtasks != nil { + edges = append(edges, user.EdgeTasks) + } + if m.removedentities != nil { + edges = append(edges, user.EdgeEntities) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. 
+func (m *UserMutation) RemovedIDs(name string) []ent.Value {
+	switch name {
+	case user.EdgeFiles:
+		ids := make([]ent.Value, 0, len(m.removedfiles))
+		for id := range m.removedfiles {
+			ids = append(ids, id)
+		}
+		return ids
+	case user.EdgeDavAccounts:
+		ids := make([]ent.Value, 0, len(m.removeddav_accounts))
+		for id := range m.removeddav_accounts {
+			ids = append(ids, id)
+		}
+		return ids
+	case user.EdgeShares:
+		ids := make([]ent.Value, 0, len(m.removedshares))
+		for id := range m.removedshares {
+			ids = append(ids, id)
+		}
+		return ids
+	case user.EdgePasskey:
+		ids := make([]ent.Value, 0, len(m.removedpasskey))
+		for id := range m.removedpasskey {
+			ids = append(ids, id)
+		}
+		return ids
+	case user.EdgeTasks:
+		ids := make([]ent.Value, 0, len(m.removedtasks))
+		for id := range m.removedtasks {
+			ids = append(ids, id)
+		}
+		return ids
+	case user.EdgeEntities:
+		ids := make([]ent.Value, 0, len(m.removedentities))
+		for id := range m.removedentities {
+			ids = append(ids, id)
+		}
+		return ids
+	}
+	return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *UserMutation) ClearedEdges() []string {
+	edges := make([]string, 0, 7)
+	if m.clearedgroup {
+		edges = append(edges, user.EdgeGroup)
+	}
+	if m.clearedfiles {
+		edges = append(edges, user.EdgeFiles)
+	}
+	if m.cleareddav_accounts {
+		edges = append(edges, user.EdgeDavAccounts)
+	}
+	if m.clearedshares {
+		edges = append(edges, user.EdgeShares)
+	}
+	if m.clearedpasskey {
+		edges = append(edges, user.EdgePasskey)
+	}
+	if m.clearedtasks {
+		edges = append(edges, user.EdgeTasks)
+	}
+	if m.clearedentities {
+		edges = append(edges, user.EdgeEntities)
+	}
+	return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *UserMutation) EdgeCleared(name string) bool {
+	switch name {
+	case user.EdgeGroup:
+		return m.clearedgroup
+	case user.EdgeFiles:
+		return m.clearedfiles
+	case user.EdgeDavAccounts:
+		return m.cleareddav_accounts
+	case user.EdgeShares:
+		return m.clearedshares
+	case user.EdgePasskey:
+		return m.clearedpasskey
+	case user.EdgeTasks:
+		return m.clearedtasks
+	case user.EdgeEntities:
+		return m.clearedentities
+	}
+	return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *UserMutation) ClearEdge(name string) error {
+	switch name {
+	case user.EdgeGroup:
+		m.ClearGroup()
+		return nil
+	}
+	return fmt.Errorf("unknown User unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *UserMutation) ResetEdge(name string) error {
+	switch name {
+	case user.EdgeGroup:
+		m.ResetGroup()
+		return nil
+	case user.EdgeFiles:
+		m.ResetFiles()
+		return nil
+	case user.EdgeDavAccounts:
+		m.ResetDavAccounts()
+		return nil
+	case user.EdgeShares:
+		m.ResetShares()
+		return nil
+	case user.EdgePasskey:
+		m.ResetPasskey()
+		return nil
+	case user.EdgeTasks:
+		m.ResetTasks()
+		return nil
+	case user.EdgeEntities:
+		m.ResetEntities()
+		return nil
+	}
+	return fmt.Errorf("unknown User edge %s", name)
+}
diff --git a/ent/mutationhelper.go b/ent/mutationhelper.go
new file mode 100644
index 00000000..86d1931f
--- /dev/null
+++ b/ent/mutationhelper.go
@@ -0,0 +1,81 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *DavAccountMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *DirectLinkMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *EntityMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *FileMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *GroupMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *MetadataMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *NodeMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *PasskeyMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *SettingMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *ShareMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *StoragePolicyMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *TaskMutation) SetRawID(t int) {
+	m.id = &t
+}
+
+// SetRawID sets the raw "id" field of the mutation.
+
+func (m *UserMutation) SetRawID(t int) {
+	m.id = &t
+}
diff --git a/ent/node.go b/ent/node.go
new file mode 100644
index 00000000..d8ac30d0
--- /dev/null
+++ b/ent/node.go
@@ -0,0 +1,260 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+	"time"
+
+	"entgo.io/ent"
+	"entgo.io/ent/dialect/sql"
+	"github.com/cloudreve/Cloudreve/v4/ent/node"
+	"github.com/cloudreve/Cloudreve/v4/inventory/types"
+	"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
+)
+
+// Node is the model entity for the Node schema.
+type Node struct {
+	config `json:"-"`
+	// ID of the ent.
+	ID int `json:"id,omitempty"`
+	// CreatedAt holds the value of the "created_at" field.
+	CreatedAt time.Time `json:"created_at,omitempty"`
+	// UpdatedAt holds the value of the "updated_at" field.
+	UpdatedAt time.Time `json:"updated_at,omitempty"`
+	// DeletedAt holds the value of the "deleted_at" field.
+	DeletedAt *time.Time `json:"deleted_at,omitempty"`
+	// Status holds the value of the "status" field.
+	Status node.Status `json:"status,omitempty"`
+	// Name holds the value of the "name" field.
+	Name string `json:"name,omitempty"`
+	// Type holds the value of the "type" field.
+	Type node.Type `json:"type,omitempty"`
+	// Server holds the value of the "server" field.
+	Server string `json:"server,omitempty"`
+	// SlaveKey holds the value of the "slave_key" field.
+	SlaveKey string `json:"slave_key,omitempty"`
+	// Capabilities holds the value of the "capabilities" field.
+	Capabilities *boolset.BooleanSet `json:"capabilities,omitempty"`
+	// Settings holds the value of the "settings" field.
+	Settings *types.NodeSetting `json:"settings,omitempty"`
+	// Weight holds the value of the "weight" field.
+	Weight int `json:"weight,omitempty"`
+	// Edges holds the relations/edges for other nodes in the graph.
+	// The values are being populated by the NodeQuery when eager-loading is set.
+	Edges        NodeEdges `json:"edges"`
+	selectValues sql.SelectValues
+}
+
+// NodeEdges holds the relations/edges for other nodes in the graph.
+type NodeEdges struct {
+	// StoragePolicy holds the value of the storage_policy edge.
+ StoragePolicy []*StoragePolicy `json:"storage_policy,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// StoragePolicyOrErr returns the StoragePolicy value or an error if the edge +// was not loaded in eager-loading. +func (e NodeEdges) StoragePolicyOrErr() ([]*StoragePolicy, error) { + if e.loadedTypes[0] { + return e.StoragePolicy, nil + } + return nil, &NotLoadedError{edge: "storage_policy"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Node) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case node.FieldSettings: + values[i] = new([]byte) + case node.FieldCapabilities: + values[i] = new(boolset.BooleanSet) + case node.FieldID, node.FieldWeight: + values[i] = new(sql.NullInt64) + case node.FieldStatus, node.FieldName, node.FieldType, node.FieldServer, node.FieldSlaveKey: + values[i] = new(sql.NullString) + case node.FieldCreatedAt, node.FieldUpdatedAt, node.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Node fields. +func (n *Node) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case node.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + n.ID = int(value.Int64) + case node.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + n.CreatedAt = value.Time + } + case node.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + n.UpdatedAt = value.Time + } + case node.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + n.DeletedAt = new(time.Time) + *n.DeletedAt = value.Time + } + case node.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + n.Status = node.Status(value.String) + } + case node.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + n.Name = value.String + } + case node.FieldType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + n.Type = node.Type(value.String) + } + case node.FieldServer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field server", values[i]) + } else if value.Valid { + n.Server = value.String + } + case node.FieldSlaveKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field slave_key", values[i]) + } else if value.Valid { + n.SlaveKey = value.String + } + case node.FieldCapabilities: + if value, ok := values[i].(*boolset.BooleanSet); !ok { + return 
fmt.Errorf("unexpected type %T for field capabilities", values[i]) + } else if value != nil { + n.Capabilities = value + } + case node.FieldSettings: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field settings", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &n.Settings); err != nil { + return fmt.Errorf("unmarshal field settings: %w", err) + } + } + case node.FieldWeight: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field weight", values[i]) + } else if value.Valid { + n.Weight = int(value.Int64) + } + default: + n.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Node. +// This includes values selected through modifiers, order, etc. +func (n *Node) Value(name string) (ent.Value, error) { + return n.selectValues.Get(name) +} + +// QueryStoragePolicy queries the "storage_policy" edge of the Node entity. +func (n *Node) QueryStoragePolicy() *StoragePolicyQuery { + return NewNodeClient(n.config).QueryStoragePolicy(n) +} + +// Update returns a builder for updating this Node. +// Note that you need to call Node.Unwrap() before calling this method if this Node +// was returned from a transaction, and the transaction was committed or rolled back. +func (n *Node) Update() *NodeUpdateOne { + return NewNodeClient(n.config).UpdateOne(n) +} + +// Unwrap unwraps the Node entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (n *Node) Unwrap() *Node { + _tx, ok := n.config.driver.(*txDriver) + if !ok { + panic("ent: Node is not a transactional entity") + } + n.config.driver = _tx.drv + return n +} + +// String implements the fmt.Stringer. +func (n *Node) String() string { + var builder strings.Builder + builder.WriteString("Node(") + builder.WriteString(fmt.Sprintf("id=%v, ", n.ID)) + builder.WriteString("created_at=") + builder.WriteString(n.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(n.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := n.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", n.Status)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(n.Name) + builder.WriteString(", ") + builder.WriteString("type=") + builder.WriteString(fmt.Sprintf("%v", n.Type)) + builder.WriteString(", ") + builder.WriteString("server=") + builder.WriteString(n.Server) + builder.WriteString(", ") + builder.WriteString("slave_key=") + builder.WriteString(n.SlaveKey) + builder.WriteString(", ") + builder.WriteString("capabilities=") + builder.WriteString(fmt.Sprintf("%v", n.Capabilities)) + builder.WriteString(", ") + builder.WriteString("settings=") + builder.WriteString(fmt.Sprintf("%v", n.Settings)) + builder.WriteString(", ") + builder.WriteString("weight=") + builder.WriteString(fmt.Sprintf("%v", n.Weight)) + builder.WriteByte(')') + return builder.String() +} + +// SetStoragePolicy manually set the edge as loaded state. +func (e *Node) SetStoragePolicy(v []*StoragePolicy) { + e.Edges.StoragePolicy = v + e.Edges.loadedTypes[0] = true +} + +// Nodes is a parsable slice of Node. 
+type Nodes []*Node diff --git a/ent/node/node.go b/ent/node/node.go new file mode 100644 index 00000000..8d2c64aa --- /dev/null +++ b/ent/node/node.go @@ -0,0 +1,219 @@ +// Code generated by ent, DO NOT EDIT. + +package node + +import ( + "fmt" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +const ( + // Label holds the string label denoting the node type in the database. + Label = "node" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldServer holds the string denoting the server field in the database. + FieldServer = "server" + // FieldSlaveKey holds the string denoting the slave_key field in the database. + FieldSlaveKey = "slave_key" + // FieldCapabilities holds the string denoting the capabilities field in the database. + FieldCapabilities = "capabilities" + // FieldSettings holds the string denoting the settings field in the database. + FieldSettings = "settings" + // FieldWeight holds the string denoting the weight field in the database. + FieldWeight = "weight" + // EdgeStoragePolicy holds the string denoting the storage_policy edge name in mutations. + EdgeStoragePolicy = "storage_policy" + // Table holds the table name of the node in the database. + Table = "nodes" + // StoragePolicyTable is the table that holds the storage_policy relation/edge. + StoragePolicyTable = "storage_policies" + // StoragePolicyInverseTable is the table name for the StoragePolicy entity. + // It exists in this package in order to avoid circular dependency with the "storagepolicy" package. + StoragePolicyInverseTable = "storage_policies" + // StoragePolicyColumn is the table column denoting the storage_policy relation/edge. + StoragePolicyColumn = "node_id" +) + +// Columns holds all SQL columns for node fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldStatus, + FieldName, + FieldType, + FieldServer, + FieldSlaveKey, + FieldCapabilities, + FieldSettings, + FieldWeight, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. 
+ DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultSettings holds the default value on creation for the "settings" field. + DefaultSettings *types.NodeSetting + // DefaultWeight holds the default value on creation for the "weight" field. + DefaultWeight int +) + +// Status defines the type for the "status" enum field. +type Status string + +// Status values. +const ( + StatusActive Status = "active" + StatusSuspended Status = "suspended" +) + +func (s Status) String() string { + return string(s) +} + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s Status) error { + switch s { + case StatusActive, StatusSuspended: + return nil + default: + return fmt.Errorf("node: invalid enum value for status field: %q", s) + } +} + +// Type defines the type for the "type" enum field. +type Type string + +// Type values. +const ( + TypeMaster Type = "master" + TypeSlave Type = "slave" +) + +func (_type Type) String() string { + return string(_type) +} + +// TypeValidator is a validator for the "type" field enum values. It is called by the builders before save. +func TypeValidator(_type Type) error { + switch _type { + case TypeMaster, TypeSlave: + return nil + default: + return fmt.Errorf("node: invalid enum value for type field: %q", _type) + } +} + +// OrderOption defines the ordering options for the Node queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByServer orders the results by the server field. +func ByServer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldServer, opts...).ToFunc() +} + +// BySlaveKey orders the results by the slave_key field. +func BySlaveKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSlaveKey, opts...).ToFunc() +} + +// ByWeight orders the results by the weight field. +func ByWeight(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeight, opts...).ToFunc() +} + +// ByStoragePolicyCount orders the results by storage_policy count. 
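+//
+// An illustrative ordering sketch, assuming an initialized *ent.Client named client, a ctx
+// context.Context, and sql being entgo.io/ent/dialect/sql; it simply chains the OrderOption
+// helpers defined in this file:
+//
+//	nodes, err := client.Node.Query().
+//		Order(node.ByWeight(sql.OrderDesc()), node.ByStoragePolicyCount()).
+//		All(ctx)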
+func ByStoragePolicyCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newStoragePolicyStep(), opts...) + } +} + +// ByStoragePolicy orders the results by storage_policy terms. +func ByStoragePolicy(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStoragePolicyStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newStoragePolicyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StoragePolicyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, StoragePolicyTable, StoragePolicyColumn), + ) +} diff --git a/ent/node/where.go b/ent/node/where.go new file mode 100644 index 00000000..9298ef83 --- /dev/null +++ b/ent/node/where.go @@ -0,0 +1,610 @@ +// Code generated by ent, DO NOT EDIT. + +package node + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Node { + return predicate.Node(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Node { + return predicate.Node(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Node { + return predicate.Node(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldName, v)) +} + +// Server applies equality check predicate on the "server" field. It's identical to ServerEQ. +func Server(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldServer, v)) +} + +// SlaveKey applies equality check predicate on the "slave_key" field. It's identical to SlaveKeyEQ. 
+func SlaveKey(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldSlaveKey, v)) +} + +// Capabilities applies equality check predicate on the "capabilities" field. It's identical to CapabilitiesEQ. +func Capabilities(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCapabilities, v)) +} + +// Weight applies equality check predicate on the "weight" field. It's identical to WeightEQ. +func Weight(v int) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldWeight, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. 
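+//
+// "deleted_at" is nullable, so callers typically keep live rows with DeletedAtIsNil and
+// fetch tombstoned rows with DeletedAtNotNil (both defined below). An illustrative sketch,
+// assuming an initialized *ent.Client named client and a ctx context.Context:
+//
+//	live, err := client.Node.Query().Where(node.DeletedAtIsNil()).All(ctx)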
+func DeletedAtEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldDeletedAt)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v Status) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v Status) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...Status) predicate.Node { + return predicate.Node(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...Status) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldStatus, vs...)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Node { + return predicate.Node(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. 
+func NameLT(v string) predicate.Node { + return predicate.Node(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Node { + return predicate.Node(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Node { + return predicate.Node(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Node { + return predicate.Node(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Node { + return predicate.Node(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Node { + return predicate.Node(sql.FieldContainsFold(FieldName, v)) +} + +// TypeEQ applies the EQ predicate on the "type" field. +func TypeEQ(v Type) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v Type) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...Type) predicate.Node { + return predicate.Node(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...Type) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldType, vs...)) +} + +// ServerEQ applies the EQ predicate on the "server" field. +func ServerEQ(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldServer, v)) +} + +// ServerNEQ applies the NEQ predicate on the "server" field. +func ServerNEQ(v string) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldServer, v)) +} + +// ServerIn applies the In predicate on the "server" field. +func ServerIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldIn(FieldServer, vs...)) +} + +// ServerNotIn applies the NotIn predicate on the "server" field. +func ServerNotIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldServer, vs...)) +} + +// ServerGT applies the GT predicate on the "server" field. +func ServerGT(v string) predicate.Node { + return predicate.Node(sql.FieldGT(FieldServer, v)) +} + +// ServerGTE applies the GTE predicate on the "server" field. +func ServerGTE(v string) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldServer, v)) +} + +// ServerLT applies the LT predicate on the "server" field. +func ServerLT(v string) predicate.Node { + return predicate.Node(sql.FieldLT(FieldServer, v)) +} + +// ServerLTE applies the LTE predicate on the "server" field. +func ServerLTE(v string) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldServer, v)) +} + +// ServerContains applies the Contains predicate on the "server" field. +func ServerContains(v string) predicate.Node { + return predicate.Node(sql.FieldContains(FieldServer, v)) +} + +// ServerHasPrefix applies the HasPrefix predicate on the "server" field. 
+func ServerHasPrefix(v string) predicate.Node { + return predicate.Node(sql.FieldHasPrefix(FieldServer, v)) +} + +// ServerHasSuffix applies the HasSuffix predicate on the "server" field. +func ServerHasSuffix(v string) predicate.Node { + return predicate.Node(sql.FieldHasSuffix(FieldServer, v)) +} + +// ServerIsNil applies the IsNil predicate on the "server" field. +func ServerIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldServer)) +} + +// ServerNotNil applies the NotNil predicate on the "server" field. +func ServerNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldServer)) +} + +// ServerEqualFold applies the EqualFold predicate on the "server" field. +func ServerEqualFold(v string) predicate.Node { + return predicate.Node(sql.FieldEqualFold(FieldServer, v)) +} + +// ServerContainsFold applies the ContainsFold predicate on the "server" field. +func ServerContainsFold(v string) predicate.Node { + return predicate.Node(sql.FieldContainsFold(FieldServer, v)) +} + +// SlaveKeyEQ applies the EQ predicate on the "slave_key" field. +func SlaveKeyEQ(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldSlaveKey, v)) +} + +// SlaveKeyNEQ applies the NEQ predicate on the "slave_key" field. +func SlaveKeyNEQ(v string) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldSlaveKey, v)) +} + +// SlaveKeyIn applies the In predicate on the "slave_key" field. +func SlaveKeyIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldIn(FieldSlaveKey, vs...)) +} + +// SlaveKeyNotIn applies the NotIn predicate on the "slave_key" field. +func SlaveKeyNotIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldSlaveKey, vs...)) +} + +// SlaveKeyGT applies the GT predicate on the "slave_key" field. +func SlaveKeyGT(v string) predicate.Node { + return predicate.Node(sql.FieldGT(FieldSlaveKey, v)) +} + +// SlaveKeyGTE applies the GTE predicate on the "slave_key" field. +func SlaveKeyGTE(v string) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldSlaveKey, v)) +} + +// SlaveKeyLT applies the LT predicate on the "slave_key" field. +func SlaveKeyLT(v string) predicate.Node { + return predicate.Node(sql.FieldLT(FieldSlaveKey, v)) +} + +// SlaveKeyLTE applies the LTE predicate on the "slave_key" field. +func SlaveKeyLTE(v string) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldSlaveKey, v)) +} + +// SlaveKeyContains applies the Contains predicate on the "slave_key" field. +func SlaveKeyContains(v string) predicate.Node { + return predicate.Node(sql.FieldContains(FieldSlaveKey, v)) +} + +// SlaveKeyHasPrefix applies the HasPrefix predicate on the "slave_key" field. +func SlaveKeyHasPrefix(v string) predicate.Node { + return predicate.Node(sql.FieldHasPrefix(FieldSlaveKey, v)) +} + +// SlaveKeyHasSuffix applies the HasSuffix predicate on the "slave_key" field. +func SlaveKeyHasSuffix(v string) predicate.Node { + return predicate.Node(sql.FieldHasSuffix(FieldSlaveKey, v)) +} + +// SlaveKeyIsNil applies the IsNil predicate on the "slave_key" field. +func SlaveKeyIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldSlaveKey)) +} + +// SlaveKeyNotNil applies the NotNil predicate on the "slave_key" field. +func SlaveKeyNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldSlaveKey)) +} + +// SlaveKeyEqualFold applies the EqualFold predicate on the "slave_key" field. 
+func SlaveKeyEqualFold(v string) predicate.Node { + return predicate.Node(sql.FieldEqualFold(FieldSlaveKey, v)) +} + +// SlaveKeyContainsFold applies the ContainsFold predicate on the "slave_key" field. +func SlaveKeyContainsFold(v string) predicate.Node { + return predicate.Node(sql.FieldContainsFold(FieldSlaveKey, v)) +} + +// CapabilitiesEQ applies the EQ predicate on the "capabilities" field. +func CapabilitiesEQ(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCapabilities, v)) +} + +// CapabilitiesNEQ applies the NEQ predicate on the "capabilities" field. +func CapabilitiesNEQ(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldCapabilities, v)) +} + +// CapabilitiesIn applies the In predicate on the "capabilities" field. +func CapabilitiesIn(vs ...*boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldIn(FieldCapabilities, vs...)) +} + +// CapabilitiesNotIn applies the NotIn predicate on the "capabilities" field. +func CapabilitiesNotIn(vs ...*boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldCapabilities, vs...)) +} + +// CapabilitiesGT applies the GT predicate on the "capabilities" field. +func CapabilitiesGT(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldGT(FieldCapabilities, v)) +} + +// CapabilitiesGTE applies the GTE predicate on the "capabilities" field. +func CapabilitiesGTE(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldCapabilities, v)) +} + +// CapabilitiesLT applies the LT predicate on the "capabilities" field. +func CapabilitiesLT(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldLT(FieldCapabilities, v)) +} + +// CapabilitiesLTE applies the LTE predicate on the "capabilities" field. +func CapabilitiesLTE(v *boolset.BooleanSet) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldCapabilities, v)) +} + +// SettingsIsNil applies the IsNil predicate on the "settings" field. +func SettingsIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldSettings)) +} + +// SettingsNotNil applies the NotNil predicate on the "settings" field. +func SettingsNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldSettings)) +} + +// WeightEQ applies the EQ predicate on the "weight" field. +func WeightEQ(v int) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldWeight, v)) +} + +// WeightNEQ applies the NEQ predicate on the "weight" field. +func WeightNEQ(v int) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldWeight, v)) +} + +// WeightIn applies the In predicate on the "weight" field. +func WeightIn(vs ...int) predicate.Node { + return predicate.Node(sql.FieldIn(FieldWeight, vs...)) +} + +// WeightNotIn applies the NotIn predicate on the "weight" field. +func WeightNotIn(vs ...int) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldWeight, vs...)) +} + +// WeightGT applies the GT predicate on the "weight" field. +func WeightGT(v int) predicate.Node { + return predicate.Node(sql.FieldGT(FieldWeight, v)) +} + +// WeightGTE applies the GTE predicate on the "weight" field. +func WeightGTE(v int) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldWeight, v)) +} + +// WeightLT applies the LT predicate on the "weight" field. +func WeightLT(v int) predicate.Node { + return predicate.Node(sql.FieldLT(FieldWeight, v)) +} + +// WeightLTE applies the LTE predicate on the "weight" field. 
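+//
+// An illustrative filtering sketch combining WeightLTE with other predicates from this
+// file, assuming an initialized *ent.Client named client and a ctx context.Context:
+//
+//	nodes, err := client.Node.Query().
+//		Where(node.StatusEQ(node.StatusActive), node.WeightLTE(10), node.DeletedAtIsNil()).
+//		All(ctx)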
+func WeightLTE(v int) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldWeight, v)) +} + +// HasStoragePolicy applies the HasEdge predicate on the "storage_policy" edge. +func HasStoragePolicy() predicate.Node { + return predicate.Node(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, StoragePolicyTable, StoragePolicyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasStoragePolicyWith applies the HasEdge predicate on the "storage_policy" edge with a given conditions (other predicates). +func HasStoragePolicyWith(preds ...predicate.StoragePolicy) predicate.Node { + return predicate.Node(func(s *sql.Selector) { + step := newStoragePolicyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Node) predicate.Node { + return predicate.Node(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Node) predicate.Node { + return predicate.Node(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Node) predicate.Node { + return predicate.Node(sql.NotPredicates(p)) +} diff --git a/ent/node_create.go b/ent/node_create.go new file mode 100644 index 00000000..65076a73 --- /dev/null +++ b/ent/node_create.go @@ -0,0 +1,1180 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// NodeCreate is the builder for creating a Node entity. +type NodeCreate struct { + config + mutation *NodeMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (nc *NodeCreate) SetCreatedAt(t time.Time) *NodeCreate { + nc.mutation.SetCreatedAt(t) + return nc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (nc *NodeCreate) SetNillableCreatedAt(t *time.Time) *NodeCreate { + if t != nil { + nc.SetCreatedAt(*t) + } + return nc +} + +// SetUpdatedAt sets the "updated_at" field. +func (nc *NodeCreate) SetUpdatedAt(t time.Time) *NodeCreate { + nc.mutation.SetUpdatedAt(t) + return nc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (nc *NodeCreate) SetNillableUpdatedAt(t *time.Time) *NodeCreate { + if t != nil { + nc.SetUpdatedAt(*t) + } + return nc +} + +// SetDeletedAt sets the "deleted_at" field. +func (nc *NodeCreate) SetDeletedAt(t time.Time) *NodeCreate { + nc.mutation.SetDeletedAt(t) + return nc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (nc *NodeCreate) SetNillableDeletedAt(t *time.Time) *NodeCreate { + if t != nil { + nc.SetDeletedAt(*t) + } + return nc +} + +// SetStatus sets the "status" field. +func (nc *NodeCreate) SetStatus(n node.Status) *NodeCreate { + nc.mutation.SetStatus(n) + return nc +} + +// SetName sets the "name" field. +func (nc *NodeCreate) SetName(s string) *NodeCreate { + nc.mutation.SetName(s) + return nc +} + +// SetType sets the "type" field. 
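+//
+// An illustrative create sketch, assuming an initialized *ent.Client named client, a ctx
+// context.Context and a pre-built capability set caps *boolset.BooleanSet; status, name,
+// type and capabilities are the required fields that have no generated default:
+//
+//	n, err := client.Node.Create().
+//		SetName("slave-1").
+//		SetType(node.TypeSlave).
+//		SetStatus(node.StatusActive).
+//		SetCapabilities(caps).
+//		Save(ctx)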
+func (nc *NodeCreate) SetType(n node.Type) *NodeCreate { + nc.mutation.SetType(n) + return nc +} + +// SetServer sets the "server" field. +func (nc *NodeCreate) SetServer(s string) *NodeCreate { + nc.mutation.SetServer(s) + return nc +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (nc *NodeCreate) SetNillableServer(s *string) *NodeCreate { + if s != nil { + nc.SetServer(*s) + } + return nc +} + +// SetSlaveKey sets the "slave_key" field. +func (nc *NodeCreate) SetSlaveKey(s string) *NodeCreate { + nc.mutation.SetSlaveKey(s) + return nc +} + +// SetNillableSlaveKey sets the "slave_key" field if the given value is not nil. +func (nc *NodeCreate) SetNillableSlaveKey(s *string) *NodeCreate { + if s != nil { + nc.SetSlaveKey(*s) + } + return nc +} + +// SetCapabilities sets the "capabilities" field. +func (nc *NodeCreate) SetCapabilities(bs *boolset.BooleanSet) *NodeCreate { + nc.mutation.SetCapabilities(bs) + return nc +} + +// SetSettings sets the "settings" field. +func (nc *NodeCreate) SetSettings(ts *types.NodeSetting) *NodeCreate { + nc.mutation.SetSettings(ts) + return nc +} + +// SetWeight sets the "weight" field. +func (nc *NodeCreate) SetWeight(i int) *NodeCreate { + nc.mutation.SetWeight(i) + return nc +} + +// SetNillableWeight sets the "weight" field if the given value is not nil. +func (nc *NodeCreate) SetNillableWeight(i *int) *NodeCreate { + if i != nil { + nc.SetWeight(*i) + } + return nc +} + +// AddStoragePolicyIDs adds the "storage_policy" edge to the StoragePolicy entity by IDs. +func (nc *NodeCreate) AddStoragePolicyIDs(ids ...int) *NodeCreate { + nc.mutation.AddStoragePolicyIDs(ids...) + return nc +} + +// AddStoragePolicy adds the "storage_policy" edges to the StoragePolicy entity. +func (nc *NodeCreate) AddStoragePolicy(s ...*StoragePolicy) *NodeCreate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return nc.AddStoragePolicyIDs(ids...) +} + +// Mutation returns the NodeMutation object of the builder. +func (nc *NodeCreate) Mutation() *NodeMutation { + return nc.mutation +} + +// Save creates the Node in the database. +func (nc *NodeCreate) Save(ctx context.Context) (*Node, error) { + if err := nc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, nc.sqlSave, nc.mutation, nc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (nc *NodeCreate) SaveX(ctx context.Context) *Node { + v, err := nc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (nc *NodeCreate) Exec(ctx context.Context) error { + _, err := nc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nc *NodeCreate) ExecX(ctx context.Context) { + if err := nc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
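+//
+// The Default* values consulted here (DefaultCreatedAt, DefaultUpdatedAt, DefaultSettings
+// and DefaultWeight) are populated by the generated runtime package; as noted in
+// ent/node/node.go, the application is expected to import it for its side effects:
+//
+//	import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime"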
+func (nc *NodeCreate) defaults() error { + if _, ok := nc.mutation.CreatedAt(); !ok { + if node.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized node.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := node.DefaultCreatedAt() + nc.mutation.SetCreatedAt(v) + } + if _, ok := nc.mutation.UpdatedAt(); !ok { + if node.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized node.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := node.DefaultUpdatedAt() + nc.mutation.SetUpdatedAt(v) + } + if _, ok := nc.mutation.Settings(); !ok { + v := node.DefaultSettings + nc.mutation.SetSettings(v) + } + if _, ok := nc.mutation.Weight(); !ok { + v := node.DefaultWeight + nc.mutation.SetWeight(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (nc *NodeCreate) check() error { + if _, ok := nc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Node.created_at"`)} + } + if _, ok := nc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Node.updated_at"`)} + } + if _, ok := nc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Node.status"`)} + } + if v, ok := nc.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } + if _, ok := nc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Node.name"`)} + } + if _, ok := nc.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "Node.type"`)} + } + if v, ok := nc.mutation.GetType(); ok { + if err := node.TypeValidator(v); err != nil { + return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "Node.type": %w`, err)} + } + } + if _, ok := nc.mutation.Capabilities(); !ok { + return &ValidationError{Name: "capabilities", err: errors.New(`ent: missing required field "Node.capabilities"`)} + } + if _, ok := nc.mutation.Weight(); !ok { + return &ValidationError{Name: "weight", err: errors.New(`ent: missing required field "Node.weight"`)} + } + return nil +} + +func (nc *NodeCreate) sqlSave(ctx context.Context) (*Node, error) { + if err := nc.check(); err != nil { + return nil, err + } + _node, _spec := nc.createSpec() + if err := sqlgraph.CreateNode(ctx, nc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + nc.mutation.id = &_node.ID + nc.mutation.done = true + return _node, nil +} + +func (nc *NodeCreate) createSpec() (*Node, *sqlgraph.CreateSpec) { + var ( + _node = &Node{config: nc.config} + _spec = sqlgraph.NewCreateSpec(node.Table, sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt)) + ) + + if id, ok := nc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = nc.conflict + if value, ok := nc.mutation.CreatedAt(); ok { + _spec.SetField(node.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := nc.mutation.UpdatedAt(); ok { + _spec.SetField(node.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := nc.mutation.DeletedAt(); ok { + 
_spec.SetField(node.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := nc.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := nc.mutation.Name(); ok { + _spec.SetField(node.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := nc.mutation.GetType(); ok { + _spec.SetField(node.FieldType, field.TypeEnum, value) + _node.Type = value + } + if value, ok := nc.mutation.Server(); ok { + _spec.SetField(node.FieldServer, field.TypeString, value) + _node.Server = value + } + if value, ok := nc.mutation.SlaveKey(); ok { + _spec.SetField(node.FieldSlaveKey, field.TypeString, value) + _node.SlaveKey = value + } + if value, ok := nc.mutation.Capabilities(); ok { + _spec.SetField(node.FieldCapabilities, field.TypeBytes, value) + _node.Capabilities = value + } + if value, ok := nc.mutation.Settings(); ok { + _spec.SetField(node.FieldSettings, field.TypeJSON, value) + _node.Settings = value + } + if value, ok := nc.mutation.Weight(); ok { + _spec.SetField(node.FieldWeight, field.TypeInt, value) + _node.Weight = value + } + if nodes := nc.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Node.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.NodeUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (nc *NodeCreate) OnConflict(opts ...sql.ConflictOption) *NodeUpsertOne { + nc.conflict = opts + return &NodeUpsertOne{ + create: nc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (nc *NodeCreate) OnConflictColumns(columns ...string) *NodeUpsertOne { + nc.conflict = append(nc.conflict, sql.ConflictColumns(columns...)) + return &NodeUpsertOne{ + create: nc, + } +} + +type ( + // NodeUpsertOne is the builder for "upsert"-ing + // one Node node. + NodeUpsertOne struct { + create *NodeCreate + } + + // NodeUpsert is the "OnConflict" setter. + NodeUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *NodeUpsert) SetUpdatedAt(v time.Time) *NodeUpsert { + u.Set(node.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *NodeUpsert) UpdateUpdatedAt() *NodeUpsert { + u.SetExcluded(node.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *NodeUpsert) SetDeletedAt(v time.Time) *NodeUpsert { + u.Set(node.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. 
+func (u *NodeUpsert) UpdateDeletedAt() *NodeUpsert { + u.SetExcluded(node.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *NodeUpsert) ClearDeletedAt() *NodeUpsert { + u.SetNull(node.FieldDeletedAt) + return u +} + +// SetStatus sets the "status" field. +func (u *NodeUpsert) SetStatus(v node.Status) *NodeUpsert { + u.Set(node.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsert) UpdateStatus() *NodeUpsert { + u.SetExcluded(node.FieldStatus) + return u +} + +// SetName sets the "name" field. +func (u *NodeUpsert) SetName(v string) *NodeUpsert { + u.Set(node.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *NodeUpsert) UpdateName() *NodeUpsert { + u.SetExcluded(node.FieldName) + return u +} + +// SetType sets the "type" field. +func (u *NodeUpsert) SetType(v node.Type) *NodeUpsert { + u.Set(node.FieldType, v) + return u +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *NodeUpsert) UpdateType() *NodeUpsert { + u.SetExcluded(node.FieldType) + return u +} + +// SetServer sets the "server" field. +func (u *NodeUpsert) SetServer(v string) *NodeUpsert { + u.Set(node.FieldServer, v) + return u +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *NodeUpsert) UpdateServer() *NodeUpsert { + u.SetExcluded(node.FieldServer) + return u +} + +// ClearServer clears the value of the "server" field. +func (u *NodeUpsert) ClearServer() *NodeUpsert { + u.SetNull(node.FieldServer) + return u +} + +// SetSlaveKey sets the "slave_key" field. +func (u *NodeUpsert) SetSlaveKey(v string) *NodeUpsert { + u.Set(node.FieldSlaveKey, v) + return u +} + +// UpdateSlaveKey sets the "slave_key" field to the value that was provided on create. +func (u *NodeUpsert) UpdateSlaveKey() *NodeUpsert { + u.SetExcluded(node.FieldSlaveKey) + return u +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (u *NodeUpsert) ClearSlaveKey() *NodeUpsert { + u.SetNull(node.FieldSlaveKey) + return u +} + +// SetCapabilities sets the "capabilities" field. +func (u *NodeUpsert) SetCapabilities(v *boolset.BooleanSet) *NodeUpsert { + u.Set(node.FieldCapabilities, v) + return u +} + +// UpdateCapabilities sets the "capabilities" field to the value that was provided on create. +func (u *NodeUpsert) UpdateCapabilities() *NodeUpsert { + u.SetExcluded(node.FieldCapabilities) + return u +} + +// SetSettings sets the "settings" field. +func (u *NodeUpsert) SetSettings(v *types.NodeSetting) *NodeUpsert { + u.Set(node.FieldSettings, v) + return u +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *NodeUpsert) UpdateSettings() *NodeUpsert { + u.SetExcluded(node.FieldSettings) + return u +} + +// ClearSettings clears the value of the "settings" field. +func (u *NodeUpsert) ClearSettings() *NodeUpsert { + u.SetNull(node.FieldSettings) + return u +} + +// SetWeight sets the "weight" field. +func (u *NodeUpsert) SetWeight(v int) *NodeUpsert { + u.Set(node.FieldWeight, v) + return u +} + +// UpdateWeight sets the "weight" field to the value that was provided on create. +func (u *NodeUpsert) UpdateWeight() *NodeUpsert { + u.SetExcluded(node.FieldWeight) + return u +} + +// AddWeight adds v to the "weight" field. 
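+//
+// Unlike SetWeight above, on conflict this adds v to the value already stored in the row
+// instead of overwriting it; use SetWeight to overwrite.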
+func (u *NodeUpsert) AddWeight(v int) *NodeUpsert { + u.Add(node.FieldWeight, v) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *NodeUpsertOne) UpdateNewValues() *NodeUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(node.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *NodeUpsertOne) Ignore() *NodeUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *NodeUpsertOne) DoNothing() *NodeUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the NodeCreate.OnConflict +// documentation for more info. +func (u *NodeUpsertOne) Update(set func(*NodeUpsert)) *NodeUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&NodeUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *NodeUpsertOne) SetUpdatedAt(v time.Time) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateUpdatedAt() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *NodeUpsertOne) SetDeletedAt(v time.Time) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateDeletedAt() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *NodeUpsertOne) ClearDeletedAt() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearDeletedAt() + }) +} + +// SetStatus sets the "status" field. +func (u *NodeUpsertOne) SetStatus(v node.Status) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateStatus() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatus() + }) +} + +// SetName sets the "name" field. +func (u *NodeUpsertOne) SetName(v string) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateName() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateName() + }) +} + +// SetType sets the "type" field. 
+func (u *NodeUpsertOne) SetType(v node.Type) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateType() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateType() + }) +} + +// SetServer sets the "server" field. +func (u *NodeUpsertOne) SetServer(v string) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetServer(v) + }) +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateServer() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateServer() + }) +} + +// ClearServer clears the value of the "server" field. +func (u *NodeUpsertOne) ClearServer() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearServer() + }) +} + +// SetSlaveKey sets the "slave_key" field. +func (u *NodeUpsertOne) SetSlaveKey(v string) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetSlaveKey(v) + }) +} + +// UpdateSlaveKey sets the "slave_key" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateSlaveKey() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateSlaveKey() + }) +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (u *NodeUpsertOne) ClearSlaveKey() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearSlaveKey() + }) +} + +// SetCapabilities sets the "capabilities" field. +func (u *NodeUpsertOne) SetCapabilities(v *boolset.BooleanSet) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetCapabilities(v) + }) +} + +// UpdateCapabilities sets the "capabilities" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateCapabilities() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateCapabilities() + }) +} + +// SetSettings sets the "settings" field. +func (u *NodeUpsertOne) SetSettings(v *types.NodeSetting) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateSettings() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *NodeUpsertOne) ClearSettings() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearSettings() + }) +} + +// SetWeight sets the "weight" field. +func (u *NodeUpsertOne) SetWeight(v int) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetWeight(v) + }) +} + +// AddWeight adds v to the "weight" field. +func (u *NodeUpsertOne) AddWeight(v int) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.AddWeight(v) + }) +} + +// UpdateWeight sets the "weight" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateWeight() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateWeight() + }) +} + +// Exec executes the query. +func (u *NodeUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for NodeCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *NodeUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. 
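+//
+// An illustrative upsert sketch, assuming an initialized *ent.Client named client, a ctx
+// context.Context, a pre-built caps *boolset.BooleanSet and a unique index on the conflict
+// column:
+//
+//	id, err := client.Node.Create().
+//		SetName("slave-1").
+//		SetType(node.TypeSlave).
+//		SetStatus(node.StatusActive).
+//		SetCapabilities(caps).
+//		OnConflictColumns(node.FieldName).
+//		UpdateNewValues().
+//		ID(ctx)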
+func (u *NodeUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *NodeUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *NodeCreate) SetRawID(t int) *NodeCreate { + m.mutation.SetRawID(t) + return m +} + +// NodeCreateBulk is the builder for creating many Node entities in bulk. +type NodeCreateBulk struct { + config + err error + builders []*NodeCreate + conflict []sql.ConflictOption +} + +// Save creates the Node entities in the database. +func (ncb *NodeCreateBulk) Save(ctx context.Context) ([]*Node, error) { + if ncb.err != nil { + return nil, ncb.err + } + specs := make([]*sqlgraph.CreateSpec, len(ncb.builders)) + nodes := make([]*Node, len(ncb.builders)) + mutators := make([]Mutator, len(ncb.builders)) + for i := range ncb.builders { + func(i int, root context.Context) { + builder := ncb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*NodeMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, ncb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = ncb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, ncb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, ncb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (ncb *NodeCreateBulk) SaveX(ctx context.Context) []*Node { + v, err := ncb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (ncb *NodeCreateBulk) Exec(ctx context.Context) error { + _, err := ncb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ncb *NodeCreateBulk) ExecX(ctx context.Context) { + if err := ncb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Node.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.NodeUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (ncb *NodeCreateBulk) OnConflict(opts ...sql.ConflictOption) *NodeUpsertBulk { + ncb.conflict = opts + return &NodeUpsertBulk{ + create: ncb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (ncb *NodeCreateBulk) OnConflictColumns(columns ...string) *NodeUpsertBulk { + ncb.conflict = append(ncb.conflict, sql.ConflictColumns(columns...)) + return &NodeUpsertBulk{ + create: ncb, + } +} + +// NodeUpsertBulk is the builder for "upsert"-ing +// a bulk of Node nodes. +type NodeUpsertBulk struct { + create *NodeCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *NodeUpsertBulk) UpdateNewValues() *NodeUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(node.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Node.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *NodeUpsertBulk) Ignore() *NodeUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *NodeUpsertBulk) DoNothing() *NodeUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the NodeCreateBulk.OnConflict +// documentation for more info. +func (u *NodeUpsertBulk) Update(set func(*NodeUpsert)) *NodeUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&NodeUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *NodeUpsertBulk) SetUpdatedAt(v time.Time) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateUpdatedAt() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *NodeUpsertBulk) SetDeletedAt(v time.Time) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateDeletedAt() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *NodeUpsertBulk) ClearDeletedAt() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearDeletedAt() + }) +} + +// SetStatus sets the "status" field. 
+func (u *NodeUpsertBulk) SetStatus(v node.Status) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateStatus() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatus() + }) +} + +// SetName sets the "name" field. +func (u *NodeUpsertBulk) SetName(v string) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateName() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateName() + }) +} + +// SetType sets the "type" field. +func (u *NodeUpsertBulk) SetType(v node.Type) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateType() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateType() + }) +} + +// SetServer sets the "server" field. +func (u *NodeUpsertBulk) SetServer(v string) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetServer(v) + }) +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateServer() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateServer() + }) +} + +// ClearServer clears the value of the "server" field. +func (u *NodeUpsertBulk) ClearServer() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearServer() + }) +} + +// SetSlaveKey sets the "slave_key" field. +func (u *NodeUpsertBulk) SetSlaveKey(v string) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetSlaveKey(v) + }) +} + +// UpdateSlaveKey sets the "slave_key" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateSlaveKey() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateSlaveKey() + }) +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (u *NodeUpsertBulk) ClearSlaveKey() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearSlaveKey() + }) +} + +// SetCapabilities sets the "capabilities" field. +func (u *NodeUpsertBulk) SetCapabilities(v *boolset.BooleanSet) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetCapabilities(v) + }) +} + +// UpdateCapabilities sets the "capabilities" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateCapabilities() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateCapabilities() + }) +} + +// SetSettings sets the "settings" field. +func (u *NodeUpsertBulk) SetSettings(v *types.NodeSetting) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateSettings() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *NodeUpsertBulk) ClearSettings() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearSettings() + }) +} + +// SetWeight sets the "weight" field. +func (u *NodeUpsertBulk) SetWeight(v int) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetWeight(v) + }) +} + +// AddWeight adds v to the "weight" field. 
+func (u *NodeUpsertBulk) AddWeight(v int) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.AddWeight(v) + }) +} + +// UpdateWeight sets the "weight" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateWeight() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateWeight() + }) +} + +// Exec executes the query. +func (u *NodeUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the NodeCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for NodeCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *NodeUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/node_delete.go b/ent/node_delete.go new file mode 100644 index 00000000..5e575cba --- /dev/null +++ b/ent/node_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// NodeDelete is the builder for deleting a Node entity. +type NodeDelete struct { + config + hooks []Hook + mutation *NodeMutation +} + +// Where appends a list predicates to the NodeDelete builder. +func (nd *NodeDelete) Where(ps ...predicate.Node) *NodeDelete { + nd.mutation.Where(ps...) + return nd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (nd *NodeDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, nd.sqlExec, nd.mutation, nd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (nd *NodeDelete) ExecX(ctx context.Context) int { + n, err := nd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (nd *NodeDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(node.Table, sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt)) + if ps := nd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, nd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + nd.mutation.done = true + return affected, err +} + +// NodeDeleteOne is the builder for deleting a single Node entity. +type NodeDeleteOne struct { + nd *NodeDelete +} + +// Where appends a list predicates to the NodeDelete builder. +func (ndo *NodeDeleteOne) Where(ps ...predicate.Node) *NodeDeleteOne { + ndo.nd.mutation.Where(ps...) + return ndo +} + +// Exec executes the deletion query. +func (ndo *NodeDeleteOne) Exec(ctx context.Context) error { + n, err := ndo.nd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{node.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (ndo *NodeDeleteOne) ExecX(ctx context.Context) { + if err := ndo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/node_query.go b/ent/node_query.go new file mode 100644 index 00000000..f12bf905 --- /dev/null +++ b/ent/node_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" +) + +// NodeQuery is the builder for querying Node entities. +type NodeQuery struct { + config + ctx *QueryContext + order []node.OrderOption + inters []Interceptor + predicates []predicate.Node + withStoragePolicy *StoragePolicyQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the NodeQuery builder. +func (nq *NodeQuery) Where(ps ...predicate.Node) *NodeQuery { + nq.predicates = append(nq.predicates, ps...) + return nq +} + +// Limit the number of records to be returned by this query. +func (nq *NodeQuery) Limit(limit int) *NodeQuery { + nq.ctx.Limit = &limit + return nq +} + +// Offset to start from. +func (nq *NodeQuery) Offset(offset int) *NodeQuery { + nq.ctx.Offset = &offset + return nq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (nq *NodeQuery) Unique(unique bool) *NodeQuery { + nq.ctx.Unique = &unique + return nq +} + +// Order specifies how the records should be ordered. +func (nq *NodeQuery) Order(o ...node.OrderOption) *NodeQuery { + nq.order = append(nq.order, o...) + return nq +} + +// QueryStoragePolicy chains the current query on the "storage_policy" edge. +func (nq *NodeQuery) QueryStoragePolicy() *StoragePolicyQuery { + query := (&StoragePolicyClient{config: nq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := nq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := nq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(node.Table, node.FieldID, selector), + sqlgraph.To(storagepolicy.Table, storagepolicy.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, node.StoragePolicyTable, node.StoragePolicyColumn), + ) + fromU = sqlgraph.SetNeighbors(nq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Node entity from the query. +// Returns a *NotFoundError when no Node was found. +func (nq *NodeQuery) First(ctx context.Context) (*Node, error) { + nodes, err := nq.Limit(1).All(setContextOp(ctx, nq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{node.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (nq *NodeQuery) FirstX(ctx context.Context) *Node { + node, err := nq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Node ID from the query. +// Returns a *NotFoundError when no Node ID was found. 
+func (nq *NodeQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = nq.Limit(1).IDs(setContextOp(ctx, nq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{node.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (nq *NodeQuery) FirstIDX(ctx context.Context) int { + id, err := nq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Node entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Node entity is found. +// Returns a *NotFoundError when no Node entities are found. +func (nq *NodeQuery) Only(ctx context.Context) (*Node, error) { + nodes, err := nq.Limit(2).All(setContextOp(ctx, nq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{node.Label} + default: + return nil, &NotSingularError{node.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (nq *NodeQuery) OnlyX(ctx context.Context) *Node { + node, err := nq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Node ID in the query. +// Returns a *NotSingularError when more than one Node ID is found. +// Returns a *NotFoundError when no entities are found. +func (nq *NodeQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = nq.Limit(2).IDs(setContextOp(ctx, nq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{node.Label} + default: + err = &NotSingularError{node.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (nq *NodeQuery) OnlyIDX(ctx context.Context) int { + id, err := nq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Nodes. +func (nq *NodeQuery) All(ctx context.Context) ([]*Node, error) { + ctx = setContextOp(ctx, nq.ctx, "All") + if err := nq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Node, *NodeQuery]() + return withInterceptors[[]*Node](ctx, nq, qr, nq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (nq *NodeQuery) AllX(ctx context.Context) []*Node { + nodes, err := nq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Node IDs. +func (nq *NodeQuery) IDs(ctx context.Context) (ids []int, err error) { + if nq.ctx.Unique == nil && nq.path != nil { + nq.Unique(true) + } + ctx = setContextOp(ctx, nq.ctx, "IDs") + if err = nq.Select(node.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (nq *NodeQuery) IDsX(ctx context.Context) []int { + ids, err := nq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (nq *NodeQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, nq.ctx, "Count") + if err := nq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, nq, querierCount[*NodeQuery](), nq.inters) +} + +// CountX is like Count, but panics if an error occurs. 
+func (nq *NodeQuery) CountX(ctx context.Context) int { + count, err := nq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (nq *NodeQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, nq.ctx, "Exist") + switch _, err := nq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (nq *NodeQuery) ExistX(ctx context.Context) bool { + exist, err := nq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the NodeQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (nq *NodeQuery) Clone() *NodeQuery { + if nq == nil { + return nil + } + return &NodeQuery{ + config: nq.config, + ctx: nq.ctx.Clone(), + order: append([]node.OrderOption{}, nq.order...), + inters: append([]Interceptor{}, nq.inters...), + predicates: append([]predicate.Node{}, nq.predicates...), + withStoragePolicy: nq.withStoragePolicy.Clone(), + // clone intermediate query. + sql: nq.sql.Clone(), + path: nq.path, + } +} + +// WithStoragePolicy tells the query-builder to eager-load the nodes that are connected to +// the "storage_policy" edge. The optional arguments are used to configure the query builder of the edge. +func (nq *NodeQuery) WithStoragePolicy(opts ...func(*StoragePolicyQuery)) *NodeQuery { + query := (&StoragePolicyClient{config: nq.config}).Query() + for _, opt := range opts { + opt(query) + } + nq.withStoragePolicy = query + return nq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Node.Query(). +// GroupBy(node.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (nq *NodeQuery) GroupBy(field string, fields ...string) *NodeGroupBy { + nq.ctx.Fields = append([]string{field}, fields...) + grbuild := &NodeGroupBy{build: nq} + grbuild.flds = &nq.ctx.Fields + grbuild.label = node.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Node.Query(). +// Select(node.FieldCreatedAt). +// Scan(ctx, &v) +func (nq *NodeQuery) Select(fields ...string) *NodeSelect { + nq.ctx.Fields = append(nq.ctx.Fields, fields...) + sbuild := &NodeSelect{NodeQuery: nq} + sbuild.label = node.Label + sbuild.flds, sbuild.scan = &nq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a NodeSelect configured with the given aggregations. +func (nq *NodeQuery) Aggregate(fns ...AggregateFunc) *NodeSelect { + return nq.Select().Aggregate(fns...) 
+} + +func (nq *NodeQuery) prepareQuery(ctx context.Context) error { + for _, inter := range nq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, nq); err != nil { + return err + } + } + } + for _, f := range nq.ctx.Fields { + if !node.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if nq.path != nil { + prev, err := nq.path(ctx) + if err != nil { + return err + } + nq.sql = prev + } + return nil +} + +func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { + var ( + nodes = []*Node{} + _spec = nq.querySpec() + loadedTypes = [1]bool{ + nq.withStoragePolicy != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Node).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Node{config: nq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := nq.withStoragePolicy; query != nil { + if err := nq.loadStoragePolicy(ctx, query, nodes, + func(n *Node) { n.Edges.StoragePolicy = []*StoragePolicy{} }, + func(n *Node, e *StoragePolicy) { n.Edges.StoragePolicy = append(n.Edges.StoragePolicy, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (nq *NodeQuery) loadStoragePolicy(ctx context.Context, query *StoragePolicyQuery, nodes []*Node, init func(*Node), assign func(*Node, *StoragePolicy)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(storagepolicy.FieldNodeID) + } + query.Where(predicate.StoragePolicy(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(node.StoragePolicyColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.NodeID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "node_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { + _spec := nq.querySpec() + _spec.Node.Columns = nq.ctx.Fields + if len(nq.ctx.Fields) > 0 { + _spec.Unique = nq.ctx.Unique != nil && *nq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, nq.driver, _spec) +} + +func (nq *NodeQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(node.Table, node.Columns, sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt)) + _spec.From = nq.sql + if unique := nq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if nq.path != nil { + _spec.Unique = true + } + if fields := nq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, node.FieldID) + for i := range fields { + if fields[i] != node.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := nq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := 
range ps { + ps[i](selector) + } + } + } + if limit := nq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := nq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := nq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (nq *NodeQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(nq.driver.Dialect()) + t1 := builder.Table(node.Table) + columns := nq.ctx.Fields + if len(columns) == 0 { + columns = node.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if nq.sql != nil { + selector = nq.sql + selector.Select(selector.Columns(columns...)...) + } + if nq.ctx.Unique != nil && *nq.ctx.Unique { + selector.Distinct() + } + for _, p := range nq.predicates { + p(selector) + } + for _, p := range nq.order { + p(selector) + } + if offset := nq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := nq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// NodeGroupBy is the group-by builder for Node entities. +type NodeGroupBy struct { + selector + build *NodeQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (ngb *NodeGroupBy) Aggregate(fns ...AggregateFunc) *NodeGroupBy { + ngb.fns = append(ngb.fns, fns...) + return ngb +} + +// Scan applies the selector query and scans the result into the given value. +func (ngb *NodeGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ngb.build.ctx, "GroupBy") + if err := ngb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*NodeQuery, *NodeGroupBy](ctx, ngb.build, ngb, ngb.build.inters, v) +} + +func (ngb *NodeGroupBy) sqlScan(ctx context.Context, root *NodeQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(ngb.fns)) + for _, fn := range ngb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*ngb.flds)+len(ngb.fns)) + for _, f := range *ngb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*ngb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ngb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// NodeSelect is the builder for selecting fields of Node entities. +type NodeSelect struct { + *NodeQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ns *NodeSelect) Aggregate(fns ...AggregateFunc) *NodeSelect { + ns.fns = append(ns.fns, fns...) + return ns +} + +// Scan applies the selector query and scans the result into the given value. 
+func (ns *NodeSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ns.ctx, "Select") + if err := ns.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*NodeQuery, *NodeSelect](ctx, ns.NodeQuery, ns, ns.inters, v) +} + +func (ns *NodeSelect) sqlScan(ctx context.Context, root *NodeQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ns.fns)) + for _, fn := range ns.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ns.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ns.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/node_update.go b/ent/node_update.go new file mode 100644 index 00000000..555f2435 --- /dev/null +++ b/ent/node_update.go @@ -0,0 +1,791 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// NodeUpdate is the builder for updating Node entities. +type NodeUpdate struct { + config + hooks []Hook + mutation *NodeMutation +} + +// Where appends a list predicates to the NodeUpdate builder. +func (nu *NodeUpdate) Where(ps ...predicate.Node) *NodeUpdate { + nu.mutation.Where(ps...) + return nu +} + +// SetUpdatedAt sets the "updated_at" field. +func (nu *NodeUpdate) SetUpdatedAt(t time.Time) *NodeUpdate { + nu.mutation.SetUpdatedAt(t) + return nu +} + +// SetDeletedAt sets the "deleted_at" field. +func (nu *NodeUpdate) SetDeletedAt(t time.Time) *NodeUpdate { + nu.mutation.SetDeletedAt(t) + return nu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableDeletedAt(t *time.Time) *NodeUpdate { + if t != nil { + nu.SetDeletedAt(*t) + } + return nu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (nu *NodeUpdate) ClearDeletedAt() *NodeUpdate { + nu.mutation.ClearDeletedAt() + return nu +} + +// SetStatus sets the "status" field. +func (nu *NodeUpdate) SetStatus(n node.Status) *NodeUpdate { + nu.mutation.SetStatus(n) + return nu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableStatus(n *node.Status) *NodeUpdate { + if n != nil { + nu.SetStatus(*n) + } + return nu +} + +// SetName sets the "name" field. +func (nu *NodeUpdate) SetName(s string) *NodeUpdate { + nu.mutation.SetName(s) + return nu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableName(s *string) *NodeUpdate { + if s != nil { + nu.SetName(*s) + } + return nu +} + +// SetType sets the "type" field. +func (nu *NodeUpdate) SetType(n node.Type) *NodeUpdate { + nu.mutation.SetType(n) + return nu +} + +// SetNillableType sets the "type" field if the given value is not nil. 
+func (nu *NodeUpdate) SetNillableType(n *node.Type) *NodeUpdate { + if n != nil { + nu.SetType(*n) + } + return nu +} + +// SetServer sets the "server" field. +func (nu *NodeUpdate) SetServer(s string) *NodeUpdate { + nu.mutation.SetServer(s) + return nu +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableServer(s *string) *NodeUpdate { + if s != nil { + nu.SetServer(*s) + } + return nu +} + +// ClearServer clears the value of the "server" field. +func (nu *NodeUpdate) ClearServer() *NodeUpdate { + nu.mutation.ClearServer() + return nu +} + +// SetSlaveKey sets the "slave_key" field. +func (nu *NodeUpdate) SetSlaveKey(s string) *NodeUpdate { + nu.mutation.SetSlaveKey(s) + return nu +} + +// SetNillableSlaveKey sets the "slave_key" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableSlaveKey(s *string) *NodeUpdate { + if s != nil { + nu.SetSlaveKey(*s) + } + return nu +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (nu *NodeUpdate) ClearSlaveKey() *NodeUpdate { + nu.mutation.ClearSlaveKey() + return nu +} + +// SetCapabilities sets the "capabilities" field. +func (nu *NodeUpdate) SetCapabilities(bs *boolset.BooleanSet) *NodeUpdate { + nu.mutation.SetCapabilities(bs) + return nu +} + +// SetSettings sets the "settings" field. +func (nu *NodeUpdate) SetSettings(ts *types.NodeSetting) *NodeUpdate { + nu.mutation.SetSettings(ts) + return nu +} + +// ClearSettings clears the value of the "settings" field. +func (nu *NodeUpdate) ClearSettings() *NodeUpdate { + nu.mutation.ClearSettings() + return nu +} + +// SetWeight sets the "weight" field. +func (nu *NodeUpdate) SetWeight(i int) *NodeUpdate { + nu.mutation.ResetWeight() + nu.mutation.SetWeight(i) + return nu +} + +// SetNillableWeight sets the "weight" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableWeight(i *int) *NodeUpdate { + if i != nil { + nu.SetWeight(*i) + } + return nu +} + +// AddWeight adds i to the "weight" field. +func (nu *NodeUpdate) AddWeight(i int) *NodeUpdate { + nu.mutation.AddWeight(i) + return nu +} + +// AddStoragePolicyIDs adds the "storage_policy" edge to the StoragePolicy entity by IDs. +func (nu *NodeUpdate) AddStoragePolicyIDs(ids ...int) *NodeUpdate { + nu.mutation.AddStoragePolicyIDs(ids...) + return nu +} + +// AddStoragePolicy adds the "storage_policy" edges to the StoragePolicy entity. +func (nu *NodeUpdate) AddStoragePolicy(s ...*StoragePolicy) *NodeUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return nu.AddStoragePolicyIDs(ids...) +} + +// Mutation returns the NodeMutation object of the builder. +func (nu *NodeUpdate) Mutation() *NodeMutation { + return nu.mutation +} + +// ClearStoragePolicy clears all "storage_policy" edges to the StoragePolicy entity. +func (nu *NodeUpdate) ClearStoragePolicy() *NodeUpdate { + nu.mutation.ClearStoragePolicy() + return nu +} + +// RemoveStoragePolicyIDs removes the "storage_policy" edge to StoragePolicy entities by IDs. +func (nu *NodeUpdate) RemoveStoragePolicyIDs(ids ...int) *NodeUpdate { + nu.mutation.RemoveStoragePolicyIDs(ids...) + return nu +} + +// RemoveStoragePolicy removes "storage_policy" edges to StoragePolicy entities. +func (nu *NodeUpdate) RemoveStoragePolicy(s ...*StoragePolicy) *NodeUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return nu.RemoveStoragePolicyIDs(ids...) 
+} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (nu *NodeUpdate) Save(ctx context.Context) (int, error) { + if err := nu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, nu.sqlSave, nu.mutation, nu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (nu *NodeUpdate) SaveX(ctx context.Context) int { + affected, err := nu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (nu *NodeUpdate) Exec(ctx context.Context) error { + _, err := nu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nu *NodeUpdate) ExecX(ctx context.Context) { + if err := nu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (nu *NodeUpdate) defaults() error { + if _, ok := nu.mutation.UpdatedAt(); !ok { + if node.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized node.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := node.UpdateDefaultUpdatedAt() + nu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (nu *NodeUpdate) check() error { + if v, ok := nu.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } + if v, ok := nu.mutation.GetType(); ok { + if err := node.TypeValidator(v); err != nil { + return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "Node.type": %w`, err)} + } + } + return nil +} + +func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := nu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(node.Table, node.Columns, sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt)) + if ps := nu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := nu.mutation.UpdatedAt(); ok { + _spec.SetField(node.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := nu.mutation.DeletedAt(); ok { + _spec.SetField(node.FieldDeletedAt, field.TypeTime, value) + } + if nu.mutation.DeletedAtCleared() { + _spec.ClearField(node.FieldDeletedAt, field.TypeTime) + } + if value, ok := nu.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + } + if value, ok := nu.mutation.Name(); ok { + _spec.SetField(node.FieldName, field.TypeString, value) + } + if value, ok := nu.mutation.GetType(); ok { + _spec.SetField(node.FieldType, field.TypeEnum, value) + } + if value, ok := nu.mutation.Server(); ok { + _spec.SetField(node.FieldServer, field.TypeString, value) + } + if nu.mutation.ServerCleared() { + _spec.ClearField(node.FieldServer, field.TypeString) + } + if value, ok := nu.mutation.SlaveKey(); ok { + _spec.SetField(node.FieldSlaveKey, field.TypeString, value) + } + if nu.mutation.SlaveKeyCleared() { + _spec.ClearField(node.FieldSlaveKey, field.TypeString) + } + if value, ok := nu.mutation.Capabilities(); ok { + _spec.SetField(node.FieldCapabilities, field.TypeBytes, value) + } + if value, ok := nu.mutation.Settings(); ok { + _spec.SetField(node.FieldSettings, field.TypeJSON, value) + } + if nu.mutation.SettingsCleared() { + _spec.ClearField(node.FieldSettings, field.TypeJSON) + } + if value, ok := nu.mutation.Weight(); ok { + 
_spec.SetField(node.FieldWeight, field.TypeInt, value) + } + if value, ok := nu.mutation.AddedWeight(); ok { + _spec.AddField(node.FieldWeight, field.TypeInt, value) + } + if nu.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nu.mutation.RemovedStoragePolicyIDs(); len(nodes) > 0 && !nu.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nu.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, nu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{node.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + nu.mutation.done = true + return n, nil +} + +// NodeUpdateOne is the builder for updating a single Node entity. +type NodeUpdateOne struct { + config + fields []string + hooks []Hook + mutation *NodeMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (nuo *NodeUpdateOne) SetUpdatedAt(t time.Time) *NodeUpdateOne { + nuo.mutation.SetUpdatedAt(t) + return nuo +} + +// SetDeletedAt sets the "deleted_at" field. +func (nuo *NodeUpdateOne) SetDeletedAt(t time.Time) *NodeUpdateOne { + nuo.mutation.SetDeletedAt(t) + return nuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableDeletedAt(t *time.Time) *NodeUpdateOne { + if t != nil { + nuo.SetDeletedAt(*t) + } + return nuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (nuo *NodeUpdateOne) ClearDeletedAt() *NodeUpdateOne { + nuo.mutation.ClearDeletedAt() + return nuo +} + +// SetStatus sets the "status" field. +func (nuo *NodeUpdateOne) SetStatus(n node.Status) *NodeUpdateOne { + nuo.mutation.SetStatus(n) + return nuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableStatus(n *node.Status) *NodeUpdateOne { + if n != nil { + nuo.SetStatus(*n) + } + return nuo +} + +// SetName sets the "name" field. +func (nuo *NodeUpdateOne) SetName(s string) *NodeUpdateOne { + nuo.mutation.SetName(s) + return nuo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableName(s *string) *NodeUpdateOne { + if s != nil { + nuo.SetName(*s) + } + return nuo +} + +// SetType sets the "type" field. 
+func (nuo *NodeUpdateOne) SetType(n node.Type) *NodeUpdateOne { + nuo.mutation.SetType(n) + return nuo +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableType(n *node.Type) *NodeUpdateOne { + if n != nil { + nuo.SetType(*n) + } + return nuo +} + +// SetServer sets the "server" field. +func (nuo *NodeUpdateOne) SetServer(s string) *NodeUpdateOne { + nuo.mutation.SetServer(s) + return nuo +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableServer(s *string) *NodeUpdateOne { + if s != nil { + nuo.SetServer(*s) + } + return nuo +} + +// ClearServer clears the value of the "server" field. +func (nuo *NodeUpdateOne) ClearServer() *NodeUpdateOne { + nuo.mutation.ClearServer() + return nuo +} + +// SetSlaveKey sets the "slave_key" field. +func (nuo *NodeUpdateOne) SetSlaveKey(s string) *NodeUpdateOne { + nuo.mutation.SetSlaveKey(s) + return nuo +} + +// SetNillableSlaveKey sets the "slave_key" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableSlaveKey(s *string) *NodeUpdateOne { + if s != nil { + nuo.SetSlaveKey(*s) + } + return nuo +} + +// ClearSlaveKey clears the value of the "slave_key" field. +func (nuo *NodeUpdateOne) ClearSlaveKey() *NodeUpdateOne { + nuo.mutation.ClearSlaveKey() + return nuo +} + +// SetCapabilities sets the "capabilities" field. +func (nuo *NodeUpdateOne) SetCapabilities(bs *boolset.BooleanSet) *NodeUpdateOne { + nuo.mutation.SetCapabilities(bs) + return nuo +} + +// SetSettings sets the "settings" field. +func (nuo *NodeUpdateOne) SetSettings(ts *types.NodeSetting) *NodeUpdateOne { + nuo.mutation.SetSettings(ts) + return nuo +} + +// ClearSettings clears the value of the "settings" field. +func (nuo *NodeUpdateOne) ClearSettings() *NodeUpdateOne { + nuo.mutation.ClearSettings() + return nuo +} + +// SetWeight sets the "weight" field. +func (nuo *NodeUpdateOne) SetWeight(i int) *NodeUpdateOne { + nuo.mutation.ResetWeight() + nuo.mutation.SetWeight(i) + return nuo +} + +// SetNillableWeight sets the "weight" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableWeight(i *int) *NodeUpdateOne { + if i != nil { + nuo.SetWeight(*i) + } + return nuo +} + +// AddWeight adds i to the "weight" field. +func (nuo *NodeUpdateOne) AddWeight(i int) *NodeUpdateOne { + nuo.mutation.AddWeight(i) + return nuo +} + +// AddStoragePolicyIDs adds the "storage_policy" edge to the StoragePolicy entity by IDs. +func (nuo *NodeUpdateOne) AddStoragePolicyIDs(ids ...int) *NodeUpdateOne { + nuo.mutation.AddStoragePolicyIDs(ids...) + return nuo +} + +// AddStoragePolicy adds the "storage_policy" edges to the StoragePolicy entity. +func (nuo *NodeUpdateOne) AddStoragePolicy(s ...*StoragePolicy) *NodeUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return nuo.AddStoragePolicyIDs(ids...) +} + +// Mutation returns the NodeMutation object of the builder. +func (nuo *NodeUpdateOne) Mutation() *NodeMutation { + return nuo.mutation +} + +// ClearStoragePolicy clears all "storage_policy" edges to the StoragePolicy entity. +func (nuo *NodeUpdateOne) ClearStoragePolicy() *NodeUpdateOne { + nuo.mutation.ClearStoragePolicy() + return nuo +} + +// RemoveStoragePolicyIDs removes the "storage_policy" edge to StoragePolicy entities by IDs. +func (nuo *NodeUpdateOne) RemoveStoragePolicyIDs(ids ...int) *NodeUpdateOne { + nuo.mutation.RemoveStoragePolicyIDs(ids...) 
+ return nuo +} + +// RemoveStoragePolicy removes "storage_policy" edges to StoragePolicy entities. +func (nuo *NodeUpdateOne) RemoveStoragePolicy(s ...*StoragePolicy) *NodeUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return nuo.RemoveStoragePolicyIDs(ids...) +} + +// Where appends a list predicates to the NodeUpdate builder. +func (nuo *NodeUpdateOne) Where(ps ...predicate.Node) *NodeUpdateOne { + nuo.mutation.Where(ps...) + return nuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (nuo *NodeUpdateOne) Select(field string, fields ...string) *NodeUpdateOne { + nuo.fields = append([]string{field}, fields...) + return nuo +} + +// Save executes the query and returns the updated Node entity. +func (nuo *NodeUpdateOne) Save(ctx context.Context) (*Node, error) { + if err := nuo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, nuo.sqlSave, nuo.mutation, nuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (nuo *NodeUpdateOne) SaveX(ctx context.Context) *Node { + node, err := nuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (nuo *NodeUpdateOne) Exec(ctx context.Context) error { + _, err := nuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nuo *NodeUpdateOne) ExecX(ctx context.Context) { + if err := nuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (nuo *NodeUpdateOne) defaults() error { + if _, ok := nuo.mutation.UpdatedAt(); !ok { + if node.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized node.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := node.UpdateDefaultUpdatedAt() + nuo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (nuo *NodeUpdateOne) check() error { + if v, ok := nuo.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } + if v, ok := nuo.mutation.GetType(); ok { + if err := node.TypeValidator(v); err != nil { + return &ValidationError{Name: "type", err: fmt.Errorf(`ent: validator failed for field "Node.type": %w`, err)} + } + } + return nil +} + +func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (_node *Node, err error) { + if err := nuo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(node.Table, node.Columns, sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt)) + id, ok := nuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Node.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := nuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, node.FieldID) + for _, f := range fields { + if !node.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != node.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := nuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := nuo.mutation.UpdatedAt(); ok { + _spec.SetField(node.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := nuo.mutation.DeletedAt(); ok { + _spec.SetField(node.FieldDeletedAt, field.TypeTime, value) + } + if nuo.mutation.DeletedAtCleared() { + _spec.ClearField(node.FieldDeletedAt, field.TypeTime) + } + if value, ok := nuo.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + } + if value, ok := nuo.mutation.Name(); ok { + _spec.SetField(node.FieldName, field.TypeString, value) + } + if value, ok := nuo.mutation.GetType(); ok { + _spec.SetField(node.FieldType, field.TypeEnum, value) + } + if value, ok := nuo.mutation.Server(); ok { + _spec.SetField(node.FieldServer, field.TypeString, value) + } + if nuo.mutation.ServerCleared() { + _spec.ClearField(node.FieldServer, field.TypeString) + } + if value, ok := nuo.mutation.SlaveKey(); ok { + _spec.SetField(node.FieldSlaveKey, field.TypeString, value) + } + if nuo.mutation.SlaveKeyCleared() { + _spec.ClearField(node.FieldSlaveKey, field.TypeString) + } + if value, ok := nuo.mutation.Capabilities(); ok { + _spec.SetField(node.FieldCapabilities, field.TypeBytes, value) + } + if value, ok := nuo.mutation.Settings(); ok { + _spec.SetField(node.FieldSettings, field.TypeJSON, value) + } + if nuo.mutation.SettingsCleared() { + _spec.ClearField(node.FieldSettings, field.TypeJSON) + } + if value, ok := nuo.mutation.Weight(); ok { + _spec.SetField(node.FieldWeight, field.TypeInt, value) + } + if value, ok := nuo.mutation.AddedWeight(); ok { + _spec.AddField(node.FieldWeight, field.TypeInt, value) + } + if nuo.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nuo.mutation.RemovedStoragePolicyIDs(); len(nodes) > 0 && 
!nuo.mutation.StoragePolicyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nuo.mutation.StoragePolicyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.StoragePolicyTable, + Columns: []string{node.StoragePolicyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Node{config: nuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, nuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{node.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + nuo.mutation.done = true + return _node, nil +} diff --git a/ent/passkey.go b/ent/passkey.go new file mode 100644 index 00000000..51e653db --- /dev/null +++ b/ent/passkey.go @@ -0,0 +1,231 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/go-webauthn/webauthn/webauthn" +) + +// Passkey is the model entity for the Passkey schema. +type Passkey struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int `json:"user_id,omitempty"` + // CredentialID holds the value of the "credential_id" field. + CredentialID string `json:"credential_id,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Credential holds the value of the "credential" field. + Credential *webauthn.Credential `json:"-"` + // UsedAt holds the value of the "used_at" field. + UsedAt *time.Time `json:"used_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PasskeyQuery when eager-loading is set. + Edges PasskeyEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PasskeyEdges holds the relations/edges for other nodes in the graph. +type PasskeyEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. 
+func (e PasskeyEdges) UserOrErr() (*User, error) { + if e.loadedTypes[0] { + if e.User == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.User, nil + } + return nil, &NotLoadedError{edge: "user"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Passkey) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case passkey.FieldCredential: + values[i] = new([]byte) + case passkey.FieldID, passkey.FieldUserID: + values[i] = new(sql.NullInt64) + case passkey.FieldCredentialID, passkey.FieldName: + values[i] = new(sql.NullString) + case passkey.FieldCreatedAt, passkey.FieldUpdatedAt, passkey.FieldDeletedAt, passkey.FieldUsedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Passkey fields. +func (pa *Passkey) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case passkey.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + pa.ID = int(value.Int64) + case passkey.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + pa.CreatedAt = value.Time + } + case passkey.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + pa.UpdatedAt = value.Time + } + case passkey.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + pa.DeletedAt = new(time.Time) + *pa.DeletedAt = value.Time + } + case passkey.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + pa.UserID = int(value.Int64) + } + case passkey.FieldCredentialID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field credential_id", values[i]) + } else if value.Valid { + pa.CredentialID = value.String + } + case passkey.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + pa.Name = value.String + } + case passkey.FieldCredential: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field credential", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &pa.Credential); err != nil { + return fmt.Errorf("unmarshal field credential: %w", err) + } + } + case passkey.FieldUsedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field used_at", values[i]) + } else if value.Valid { + pa.UsedAt = new(time.Time) + *pa.UsedAt = value.Time + } + default: + pa.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Passkey. 
+// This includes values selected through modifiers, order, etc. +func (pa *Passkey) Value(name string) (ent.Value, error) { + return pa.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the Passkey entity. +func (pa *Passkey) QueryUser() *UserQuery { + return NewPasskeyClient(pa.config).QueryUser(pa) +} + +// Update returns a builder for updating this Passkey. +// Note that you need to call Passkey.Unwrap() before calling this method if this Passkey +// was returned from a transaction, and the transaction was committed or rolled back. +func (pa *Passkey) Update() *PasskeyUpdateOne { + return NewPasskeyClient(pa.config).UpdateOne(pa) +} + +// Unwrap unwraps the Passkey entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (pa *Passkey) Unwrap() *Passkey { + _tx, ok := pa.config.driver.(*txDriver) + if !ok { + panic("ent: Passkey is not a transactional entity") + } + pa.config.driver = _tx.drv + return pa +} + +// String implements the fmt.Stringer. +func (pa *Passkey) String() string { + var builder strings.Builder + builder.WriteString("Passkey(") + builder.WriteString(fmt.Sprintf("id=%v, ", pa.ID)) + builder.WriteString("created_at=") + builder.WriteString(pa.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(pa.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := pa.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", pa.UserID)) + builder.WriteString(", ") + builder.WriteString("credential_id=") + builder.WriteString(pa.CredentialID) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(pa.Name) + builder.WriteString(", ") + builder.WriteString("credential=") + builder.WriteString(", ") + if v := pa.UsedAt; v != nil { + builder.WriteString("used_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// SetUser manually set the edge as loaded state. +func (e *Passkey) SetUser(v *User) { + e.Edges.User = v + e.Edges.loadedTypes[0] = true +} + +// Passkeys is a parsable slice of Passkey. +type Passkeys []*Passkey diff --git a/ent/passkey/passkey.go b/ent/passkey/passkey.go new file mode 100644 index 00000000..01824482 --- /dev/null +++ b/ent/passkey/passkey.go @@ -0,0 +1,141 @@ +// Code generated by ent, DO NOT EDIT. + +package passkey + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the passkey type in the database. + Label = "passkey" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldCredentialID holds the string denoting the credential_id field in the database. 
+ FieldCredentialID = "credential_id" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldCredential holds the string denoting the credential field in the database. + FieldCredential = "credential" + // FieldUsedAt holds the string denoting the used_at field in the database. + FieldUsedAt = "used_at" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // Table holds the table name of the passkey in the database. + Table = "passkeys" + // UserTable is the table that holds the user relation/edge. + UserTable = "passkeys" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" +) + +// Columns holds all SQL columns for passkey fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldUserID, + FieldCredentialID, + FieldName, + FieldCredential, + FieldUsedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the Passkey queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByCredentialID orders the results by the credential_id field. +func ByCredentialID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCredentialID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByUsedAt orders the results by the used_at field. 
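// The OrderOption helpers above are meant to be passed to PasskeyQuery.Order.
// A minimal usage sketch (illustrative only, not generated code; client, ctx
// and uid are assumed placeholders), listing a user's passkeys with the most
// recently used first:
//
//	pks, err := client.Passkey.Query().
//		Where(passkey.UserID(uid)).
//		Order(passkey.ByUsedAt(sql.OrderDesc()), passkey.ByID()).
//		All(ctx)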
+func ByUsedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsedAt, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} diff --git a/ent/passkey/where.go b/ent/passkey/where.go new file mode 100644 index 00000000..3eeeac27 --- /dev/null +++ b/ent/passkey/where.go @@ -0,0 +1,459 @@ +// Code generated by ent, DO NOT EDIT. + +package passkey + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldDeletedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUserID, v)) +} + +// CredentialID applies equality check predicate on the "credential_id" field. It's identical to CredentialIDEQ. +func CredentialID(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldCredentialID, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. 
+func Name(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldName, v)) +} + +// UsedAt applies equality check predicate on the "used_at" field. It's identical to UsedAtEQ. +func UsedAt(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUsedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. 
+func DeletedAtEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Passkey { + return predicate.Passkey(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Passkey { + return predicate.Passkey(sql.FieldNotNull(FieldDeletedAt)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldUserID, vs...)) +} + +// CredentialIDEQ applies the EQ predicate on the "credential_id" field. +func CredentialIDEQ(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldCredentialID, v)) +} + +// CredentialIDNEQ applies the NEQ predicate on the "credential_id" field. +func CredentialIDNEQ(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldCredentialID, v)) +} + +// CredentialIDIn applies the In predicate on the "credential_id" field. +func CredentialIDIn(vs ...string) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldCredentialID, vs...)) +} + +// CredentialIDNotIn applies the NotIn predicate on the "credential_id" field. +func CredentialIDNotIn(vs ...string) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldCredentialID, vs...)) +} + +// CredentialIDGT applies the GT predicate on the "credential_id" field. 
+func CredentialIDGT(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldCredentialID, v)) +} + +// CredentialIDGTE applies the GTE predicate on the "credential_id" field. +func CredentialIDGTE(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldCredentialID, v)) +} + +// CredentialIDLT applies the LT predicate on the "credential_id" field. +func CredentialIDLT(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldCredentialID, v)) +} + +// CredentialIDLTE applies the LTE predicate on the "credential_id" field. +func CredentialIDLTE(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldCredentialID, v)) +} + +// CredentialIDContains applies the Contains predicate on the "credential_id" field. +func CredentialIDContains(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldContains(FieldCredentialID, v)) +} + +// CredentialIDHasPrefix applies the HasPrefix predicate on the "credential_id" field. +func CredentialIDHasPrefix(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldHasPrefix(FieldCredentialID, v)) +} + +// CredentialIDHasSuffix applies the HasSuffix predicate on the "credential_id" field. +func CredentialIDHasSuffix(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldHasSuffix(FieldCredentialID, v)) +} + +// CredentialIDEqualFold applies the EqualFold predicate on the "credential_id" field. +func CredentialIDEqualFold(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEqualFold(FieldCredentialID, v)) +} + +// CredentialIDContainsFold applies the ContainsFold predicate on the "credential_id" field. +func CredentialIDContainsFold(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldContainsFold(FieldCredentialID, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. 
+func NameHasPrefix(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Passkey { + return predicate.Passkey(sql.FieldContainsFold(FieldName, v)) +} + +// UsedAtEQ applies the EQ predicate on the "used_at" field. +func UsedAtEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldEQ(FieldUsedAt, v)) +} + +// UsedAtNEQ applies the NEQ predicate on the "used_at" field. +func UsedAtNEQ(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNEQ(FieldUsedAt, v)) +} + +// UsedAtIn applies the In predicate on the "used_at" field. +func UsedAtIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldIn(FieldUsedAt, vs...)) +} + +// UsedAtNotIn applies the NotIn predicate on the "used_at" field. +func UsedAtNotIn(vs ...time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldNotIn(FieldUsedAt, vs...)) +} + +// UsedAtGT applies the GT predicate on the "used_at" field. +func UsedAtGT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGT(FieldUsedAt, v)) +} + +// UsedAtGTE applies the GTE predicate on the "used_at" field. +func UsedAtGTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldGTE(FieldUsedAt, v)) +} + +// UsedAtLT applies the LT predicate on the "used_at" field. +func UsedAtLT(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLT(FieldUsedAt, v)) +} + +// UsedAtLTE applies the LTE predicate on the "used_at" field. +func UsedAtLTE(v time.Time) predicate.Passkey { + return predicate.Passkey(sql.FieldLTE(FieldUsedAt, v)) +} + +// UsedAtIsNil applies the IsNil predicate on the "used_at" field. +func UsedAtIsNil() predicate.Passkey { + return predicate.Passkey(sql.FieldIsNull(FieldUsedAt)) +} + +// UsedAtNotNil applies the NotNil predicate on the "used_at" field. +func UsedAtNotNil() predicate.Passkey { + return predicate.Passkey(sql.FieldNotNull(FieldUsedAt)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.Passkey { + return predicate.Passkey(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.Passkey { + return predicate.Passkey(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Passkey) predicate.Passkey { + return predicate.Passkey(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Passkey) predicate.Passkey { + return predicate.Passkey(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
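// The field predicates above compose with the And, Or and Not helpers defined
// at the end of this file. A sketch (not generated code; uid and keyword are
// placeholder values) that matches a user's non-deleted passkeys which are
// either unused or whose name matches a keyword:
//
//	p := passkey.And(
//		passkey.UserID(uid),
//		passkey.DeletedAtIsNil(),
//		passkey.Or(
//			passkey.UsedAtIsNil(),
//			passkey.NameContainsFold(keyword),
//		),
//	)
//	pks, err := client.Passkey.Query().Where(p).All(ctx)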
+func Not(p predicate.Passkey) predicate.Passkey { + return predicate.Passkey(sql.NotPredicates(p)) +} diff --git a/ent/passkey_create.go b/ent/passkey_create.go new file mode 100644 index 00000000..69c71b80 --- /dev/null +++ b/ent/passkey_create.go @@ -0,0 +1,922 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/go-webauthn/webauthn/webauthn" +) + +// PasskeyCreate is the builder for creating a Passkey entity. +type PasskeyCreate struct { + config + mutation *PasskeyMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (pc *PasskeyCreate) SetCreatedAt(t time.Time) *PasskeyCreate { + pc.mutation.SetCreatedAt(t) + return pc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (pc *PasskeyCreate) SetNillableCreatedAt(t *time.Time) *PasskeyCreate { + if t != nil { + pc.SetCreatedAt(*t) + } + return pc +} + +// SetUpdatedAt sets the "updated_at" field. +func (pc *PasskeyCreate) SetUpdatedAt(t time.Time) *PasskeyCreate { + pc.mutation.SetUpdatedAt(t) + return pc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (pc *PasskeyCreate) SetNillableUpdatedAt(t *time.Time) *PasskeyCreate { + if t != nil { + pc.SetUpdatedAt(*t) + } + return pc +} + +// SetDeletedAt sets the "deleted_at" field. +func (pc *PasskeyCreate) SetDeletedAt(t time.Time) *PasskeyCreate { + pc.mutation.SetDeletedAt(t) + return pc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (pc *PasskeyCreate) SetNillableDeletedAt(t *time.Time) *PasskeyCreate { + if t != nil { + pc.SetDeletedAt(*t) + } + return pc +} + +// SetUserID sets the "user_id" field. +func (pc *PasskeyCreate) SetUserID(i int) *PasskeyCreate { + pc.mutation.SetUserID(i) + return pc +} + +// SetCredentialID sets the "credential_id" field. +func (pc *PasskeyCreate) SetCredentialID(s string) *PasskeyCreate { + pc.mutation.SetCredentialID(s) + return pc +} + +// SetName sets the "name" field. +func (pc *PasskeyCreate) SetName(s string) *PasskeyCreate { + pc.mutation.SetName(s) + return pc +} + +// SetCredential sets the "credential" field. +func (pc *PasskeyCreate) SetCredential(w *webauthn.Credential) *PasskeyCreate { + pc.mutation.SetCredential(w) + return pc +} + +// SetUsedAt sets the "used_at" field. +func (pc *PasskeyCreate) SetUsedAt(t time.Time) *PasskeyCreate { + pc.mutation.SetUsedAt(t) + return pc +} + +// SetNillableUsedAt sets the "used_at" field if the given value is not nil. +func (pc *PasskeyCreate) SetNillableUsedAt(t *time.Time) *PasskeyCreate { + if t != nil { + pc.SetUsedAt(*t) + } + return pc +} + +// SetUser sets the "user" edge to the User entity. +func (pc *PasskeyCreate) SetUser(u *User) *PasskeyCreate { + return pc.SetUserID(u.ID) +} + +// Mutation returns the PasskeyMutation object of the builder. +func (pc *PasskeyCreate) Mutation() *PasskeyMutation { + return pc.mutation +} + +// Save creates the Passkey in the database. +func (pc *PasskeyCreate) Save(ctx context.Context) (*Passkey, error) { + if err := pc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, pc.sqlSave, pc.mutation, pc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. 
+func (pc *PasskeyCreate) SaveX(ctx context.Context) *Passkey { + v, err := pc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (pc *PasskeyCreate) Exec(ctx context.Context) error { + _, err := pc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pc *PasskeyCreate) ExecX(ctx context.Context) { + if err := pc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (pc *PasskeyCreate) defaults() error { + if _, ok := pc.mutation.CreatedAt(); !ok { + if passkey.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized passkey.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := passkey.DefaultCreatedAt() + pc.mutation.SetCreatedAt(v) + } + if _, ok := pc.mutation.UpdatedAt(); !ok { + if passkey.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized passkey.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := passkey.DefaultUpdatedAt() + pc.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (pc *PasskeyCreate) check() error { + if _, ok := pc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Passkey.created_at"`)} + } + if _, ok := pc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Passkey.updated_at"`)} + } + if _, ok := pc.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "Passkey.user_id"`)} + } + if _, ok := pc.mutation.CredentialID(); !ok { + return &ValidationError{Name: "credential_id", err: errors.New(`ent: missing required field "Passkey.credential_id"`)} + } + if _, ok := pc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Passkey.name"`)} + } + if _, ok := pc.mutation.Credential(); !ok { + return &ValidationError{Name: "credential", err: errors.New(`ent: missing required field "Passkey.credential"`)} + } + if _, ok := pc.mutation.UserID(); !ok { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "Passkey.user"`)} + } + return nil +} + +func (pc *PasskeyCreate) sqlSave(ctx context.Context) (*Passkey, error) { + if err := pc.check(); err != nil { + return nil, err + } + _node, _spec := pc.createSpec() + if err := sqlgraph.CreateNode(ctx, pc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + pc.mutation.id = &_node.ID + pc.mutation.done = true + return _node, nil +} + +func (pc *PasskeyCreate) createSpec() (*Passkey, *sqlgraph.CreateSpec) { + var ( + _node = &Passkey{config: pc.config} + _spec = sqlgraph.NewCreateSpec(passkey.Table, sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt)) + ) + + if id, ok := pc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = pc.conflict + if value, ok := pc.mutation.CreatedAt(); ok { + _spec.SetField(passkey.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := pc.mutation.UpdatedAt(); ok { + _spec.SetField(passkey.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := pc.mutation.DeletedAt(); ok { + _spec.SetField(passkey.FieldDeletedAt, 
field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := pc.mutation.CredentialID(); ok { + _spec.SetField(passkey.FieldCredentialID, field.TypeString, value) + _node.CredentialID = value + } + if value, ok := pc.mutation.Name(); ok { + _spec.SetField(passkey.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := pc.mutation.Credential(); ok { + _spec.SetField(passkey.FieldCredential, field.TypeJSON, value) + _node.Credential = value + } + if value, ok := pc.mutation.UsedAt(); ok { + _spec.SetField(passkey.FieldUsedAt, field.TypeTime, value) + _node.UsedAt = &value + } + if nodes := pc.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: passkey.UserTable, + Columns: []string{passkey.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Passkey.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PasskeyUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (pc *PasskeyCreate) OnConflict(opts ...sql.ConflictOption) *PasskeyUpsertOne { + pc.conflict = opts + return &PasskeyUpsertOne{ + create: pc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (pc *PasskeyCreate) OnConflictColumns(columns ...string) *PasskeyUpsertOne { + pc.conflict = append(pc.conflict, sql.ConflictColumns(columns...)) + return &PasskeyUpsertOne{ + create: pc, + } +} + +type ( + // PasskeyUpsertOne is the builder for "upsert"-ing + // one Passkey node. + PasskeyUpsertOne struct { + create *PasskeyCreate + } + + // PasskeyUpsert is the "OnConflict" setter. + PasskeyUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PasskeyUpsert) SetUpdatedAt(v time.Time) *PasskeyUpsert { + u.Set(passkey.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateUpdatedAt() *PasskeyUpsert { + u.SetExcluded(passkey.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *PasskeyUpsert) SetDeletedAt(v time.Time) *PasskeyUpsert { + u.Set(passkey.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateDeletedAt() *PasskeyUpsert { + u.SetExcluded(passkey.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *PasskeyUpsert) ClearDeletedAt() *PasskeyUpsert { + u.SetNull(passkey.FieldDeletedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *PasskeyUpsert) SetUserID(v int) *PasskeyUpsert { + u.Set(passkey.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. 
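// The PasskeyUpsert setters in this block are used inside
// Update(func(*ent.PasskeyUpsert)) to control exactly what changes when the
// insert conflicts. A sketch (not generated code; it assumes a unique index on
// credential_id, which is defined outside this file) that refreshes only
// used_at for an already-registered credential:
//
//	err := client.Passkey.Create().
//		SetUserID(uid).
//		SetCredentialID(cid).
//		SetName(name).
//		SetCredential(&cred).
//		OnConflictColumns(passkey.FieldCredentialID).
//		Update(func(u *ent.PasskeyUpsert) {
//			u.SetUsedAt(time.Now())
//		}).
//		Exec(ctx)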
+func (u *PasskeyUpsert) UpdateUserID() *PasskeyUpsert { + u.SetExcluded(passkey.FieldUserID) + return u +} + +// SetCredentialID sets the "credential_id" field. +func (u *PasskeyUpsert) SetCredentialID(v string) *PasskeyUpsert { + u.Set(passkey.FieldCredentialID, v) + return u +} + +// UpdateCredentialID sets the "credential_id" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateCredentialID() *PasskeyUpsert { + u.SetExcluded(passkey.FieldCredentialID) + return u +} + +// SetName sets the "name" field. +func (u *PasskeyUpsert) SetName(v string) *PasskeyUpsert { + u.Set(passkey.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateName() *PasskeyUpsert { + u.SetExcluded(passkey.FieldName) + return u +} + +// SetCredential sets the "credential" field. +func (u *PasskeyUpsert) SetCredential(v *webauthn.Credential) *PasskeyUpsert { + u.Set(passkey.FieldCredential, v) + return u +} + +// UpdateCredential sets the "credential" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateCredential() *PasskeyUpsert { + u.SetExcluded(passkey.FieldCredential) + return u +} + +// SetUsedAt sets the "used_at" field. +func (u *PasskeyUpsert) SetUsedAt(v time.Time) *PasskeyUpsert { + u.Set(passkey.FieldUsedAt, v) + return u +} + +// UpdateUsedAt sets the "used_at" field to the value that was provided on create. +func (u *PasskeyUpsert) UpdateUsedAt() *PasskeyUpsert { + u.SetExcluded(passkey.FieldUsedAt) + return u +} + +// ClearUsedAt clears the value of the "used_at" field. +func (u *PasskeyUpsert) ClearUsedAt() *PasskeyUpsert { + u.SetNull(passkey.FieldUsedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PasskeyUpsertOne) UpdateNewValues() *PasskeyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(passkey.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PasskeyUpsertOne) Ignore() *PasskeyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PasskeyUpsertOne) DoNothing() *PasskeyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PasskeyCreate.OnConflict +// documentation for more info. +func (u *PasskeyUpsertOne) Update(set func(*PasskeyUpsert)) *PasskeyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PasskeyUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PasskeyUpsertOne) SetUpdatedAt(v time.Time) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *PasskeyUpsertOne) UpdateUpdatedAt() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *PasskeyUpsertOne) SetDeletedAt(v time.Time) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateDeletedAt() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *PasskeyUpsertOne) ClearDeletedAt() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *PasskeyUpsertOne) SetUserID(v int) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateUserID() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUserID() + }) +} + +// SetCredentialID sets the "credential_id" field. +func (u *PasskeyUpsertOne) SetCredentialID(v string) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetCredentialID(v) + }) +} + +// UpdateCredentialID sets the "credential_id" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateCredentialID() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateCredentialID() + }) +} + +// SetName sets the "name" field. +func (u *PasskeyUpsertOne) SetName(v string) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateName() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateName() + }) +} + +// SetCredential sets the "credential" field. +func (u *PasskeyUpsertOne) SetCredential(v *webauthn.Credential) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetCredential(v) + }) +} + +// UpdateCredential sets the "credential" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateCredential() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateCredential() + }) +} + +// SetUsedAt sets the "used_at" field. +func (u *PasskeyUpsertOne) SetUsedAt(v time.Time) *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.SetUsedAt(v) + }) +} + +// UpdateUsedAt sets the "used_at" field to the value that was provided on create. +func (u *PasskeyUpsertOne) UpdateUsedAt() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUsedAt() + }) +} + +// ClearUsedAt clears the value of the "used_at" field. +func (u *PasskeyUpsertOne) ClearUsedAt() *PasskeyUpsertOne { + return u.Update(func(s *PasskeyUpsert) { + s.ClearUsedAt() + }) +} + +// Exec executes the query. +func (u *PasskeyUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PasskeyCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PasskeyUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. 
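// A typical single-row upsert sketch (illustrative; again assumes a unique
// index on credential_id): insert the passkey, or overwrite an existing row
// with the newly proposed values, and return the row ID either way:
//
//	id, err := client.Passkey.Create().
//		SetUserID(uid).
//		SetCredentialID(cid).
//		SetName(name).
//		SetCredential(&cred).
//		OnConflictColumns(passkey.FieldCredentialID).
//		UpdateNewValues().
//		ID(ctx)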
+func (u *PasskeyUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PasskeyUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *PasskeyCreate) SetRawID(t int) *PasskeyCreate { + m.mutation.SetRawID(t) + return m +} + +// PasskeyCreateBulk is the builder for creating many Passkey entities in bulk. +type PasskeyCreateBulk struct { + config + err error + builders []*PasskeyCreate + conflict []sql.ConflictOption +} + +// Save creates the Passkey entities in the database. +func (pcb *PasskeyCreateBulk) Save(ctx context.Context) ([]*Passkey, error) { + if pcb.err != nil { + return nil, pcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(pcb.builders)) + nodes := make([]*Passkey, len(pcb.builders)) + mutators := make([]Mutator, len(pcb.builders)) + for i := range pcb.builders { + func(i int, root context.Context) { + builder := pcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PasskeyMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, pcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = pcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, pcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, pcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (pcb *PasskeyCreateBulk) SaveX(ctx context.Context) []*Passkey { + v, err := pcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (pcb *PasskeyCreateBulk) Exec(ctx context.Context) error { + _, err := pcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pcb *PasskeyCreateBulk) ExecX(ctx context.Context) { + if err := pcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Passkey.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PasskeyUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (pcb *PasskeyCreateBulk) OnConflict(opts ...sql.ConflictOption) *PasskeyUpsertBulk { + pcb.conflict = opts + return &PasskeyUpsertBulk{ + create: pcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (pcb *PasskeyCreateBulk) OnConflictColumns(columns ...string) *PasskeyUpsertBulk { + pcb.conflict = append(pcb.conflict, sql.ConflictColumns(columns...)) + return &PasskeyUpsertBulk{ + create: pcb, + } +} + +// PasskeyUpsertBulk is the builder for "upsert"-ing +// a bulk of Passkey nodes. +type PasskeyUpsertBulk struct { + create *PasskeyCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PasskeyUpsertBulk) UpdateNewValues() *PasskeyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(passkey.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Passkey.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PasskeyUpsertBulk) Ignore() *PasskeyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PasskeyUpsertBulk) DoNothing() *PasskeyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PasskeyCreateBulk.OnConflict +// documentation for more info. +func (u *PasskeyUpsertBulk) Update(set func(*PasskeyUpsert)) *PasskeyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PasskeyUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PasskeyUpsertBulk) SetUpdatedAt(v time.Time) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateUpdatedAt() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *PasskeyUpsertBulk) SetDeletedAt(v time.Time) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateDeletedAt() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *PasskeyUpsertBulk) ClearDeletedAt() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. 
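// Bulk insertion follows the same pattern. A sketch (not generated code; it
// assumes the generated PasskeyClient exposes CreateBulk, as ent clients
// normally do, and that ids, names and creds are parallel placeholder slices)
// that inserts several passkeys in one statement and skips conflicting rows:
//
//	bulk := make([]*ent.PasskeyCreate, 0, len(creds))
//	for i := range creds {
//		bulk = append(bulk, client.Passkey.Create().
//			SetUserID(uid).
//			SetCredentialID(ids[i]).
//			SetName(names[i]).
//			SetCredential(&creds[i]))
//	}
//	err := client.Passkey.CreateBulk(bulk...).
//		OnConflictColumns(passkey.FieldCredentialID).
//		Ignore().
//		Exec(ctx)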
+func (u *PasskeyUpsertBulk) SetUserID(v int) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateUserID() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUserID() + }) +} + +// SetCredentialID sets the "credential_id" field. +func (u *PasskeyUpsertBulk) SetCredentialID(v string) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetCredentialID(v) + }) +} + +// UpdateCredentialID sets the "credential_id" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateCredentialID() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateCredentialID() + }) +} + +// SetName sets the "name" field. +func (u *PasskeyUpsertBulk) SetName(v string) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateName() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateName() + }) +} + +// SetCredential sets the "credential" field. +func (u *PasskeyUpsertBulk) SetCredential(v *webauthn.Credential) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetCredential(v) + }) +} + +// UpdateCredential sets the "credential" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateCredential() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateCredential() + }) +} + +// SetUsedAt sets the "used_at" field. +func (u *PasskeyUpsertBulk) SetUsedAt(v time.Time) *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.SetUsedAt(v) + }) +} + +// UpdateUsedAt sets the "used_at" field to the value that was provided on create. +func (u *PasskeyUpsertBulk) UpdateUsedAt() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.UpdateUsedAt() + }) +} + +// ClearUsedAt clears the value of the "used_at" field. +func (u *PasskeyUpsertBulk) ClearUsedAt() *PasskeyUpsertBulk { + return u.Update(func(s *PasskeyUpsert) { + s.ClearUsedAt() + }) +} + +// Exec executes the query. +func (u *PasskeyUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PasskeyCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PasskeyCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PasskeyUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/passkey_delete.go b/ent/passkey_delete.go new file mode 100644 index 00000000..a4a3bec7 --- /dev/null +++ b/ent/passkey_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// PasskeyDelete is the builder for deleting a Passkey entity. +type PasskeyDelete struct { + config + hooks []Hook + mutation *PasskeyMutation +} + +// Where appends a list predicates to the PasskeyDelete builder. 
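// A deletion sketch (illustrative; uid and cid are placeholders). PasskeyDelete
// issues a hard DELETE; given the deleted_at column and the runtime hooks and
// interceptors registered for this entity, soft deletion is presumably handled
// elsewhere and is not shown in this file:
//
//	n, err := client.Passkey.Delete().
//		Where(passkey.UserID(uid), passkey.CredentialID(cid)).
//		Exec(ctx)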
+func (pd *PasskeyDelete) Where(ps ...predicate.Passkey) *PasskeyDelete { + pd.mutation.Where(ps...) + return pd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (pd *PasskeyDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, pd.sqlExec, pd.mutation, pd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (pd *PasskeyDelete) ExecX(ctx context.Context) int { + n, err := pd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (pd *PasskeyDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(passkey.Table, sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt)) + if ps := pd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, pd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + pd.mutation.done = true + return affected, err +} + +// PasskeyDeleteOne is the builder for deleting a single Passkey entity. +type PasskeyDeleteOne struct { + pd *PasskeyDelete +} + +// Where appends a list predicates to the PasskeyDelete builder. +func (pdo *PasskeyDeleteOne) Where(ps ...predicate.Passkey) *PasskeyDeleteOne { + pdo.pd.mutation.Where(ps...) + return pdo +} + +// Exec executes the deletion query. +func (pdo *PasskeyDeleteOne) Exec(ctx context.Context) error { + n, err := pdo.pd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{passkey.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (pdo *PasskeyDeleteOne) ExecX(ctx context.Context) { + if err := pdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/passkey_query.go b/ent/passkey_query.go new file mode 100644 index 00000000..2fc73c2d --- /dev/null +++ b/ent/passkey_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// PasskeyQuery is the builder for querying Passkey entities. +type PasskeyQuery struct { + config + ctx *QueryContext + order []passkey.OrderOption + inters []Interceptor + predicates []predicate.Passkey + withUser *UserQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PasskeyQuery builder. +func (pq *PasskeyQuery) Where(ps ...predicate.Passkey) *PasskeyQuery { + pq.predicates = append(pq.predicates, ps...) + return pq +} + +// Limit the number of records to be returned by this query. +func (pq *PasskeyQuery) Limit(limit int) *PasskeyQuery { + pq.ctx.Limit = &limit + return pq +} + +// Offset to start from. +func (pq *PasskeyQuery) Offset(offset int) *PasskeyQuery { + pq.ctx.Offset = &offset + return pq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (pq *PasskeyQuery) Unique(unique bool) *PasskeyQuery { + pq.ctx.Unique = &unique + return pq +} + +// Order specifies how the records should be ordered. 
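// A paginated listing sketch (illustrative; page and pageSize are placeholder
// values) combining Where, Order, Limit and Offset from this builder:
//
//	pks, err := client.Passkey.Query().
//		Where(passkey.UserID(uid), passkey.DeletedAtIsNil()).
//		Order(passkey.ByCreatedAt(sql.OrderDesc())).
//		Limit(pageSize).
//		Offset(page * pageSize).
//		All(ctx)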
+func (pq *PasskeyQuery) Order(o ...passkey.OrderOption) *PasskeyQuery { + pq.order = append(pq.order, o...) + return pq +} + +// QueryUser chains the current query on the "user" edge. +func (pq *PasskeyQuery) QueryUser() *UserQuery { + query := (&UserClient{config: pq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := pq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := pq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(passkey.Table, passkey.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, passkey.UserTable, passkey.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Passkey entity from the query. +// Returns a *NotFoundError when no Passkey was found. +func (pq *PasskeyQuery) First(ctx context.Context) (*Passkey, error) { + nodes, err := pq.Limit(1).All(setContextOp(ctx, pq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{passkey.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (pq *PasskeyQuery) FirstX(ctx context.Context) *Passkey { + node, err := pq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Passkey ID from the query. +// Returns a *NotFoundError when no Passkey ID was found. +func (pq *PasskeyQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = pq.Limit(1).IDs(setContextOp(ctx, pq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{passkey.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (pq *PasskeyQuery) FirstIDX(ctx context.Context) int { + id, err := pq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Passkey entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Passkey entity is found. +// Returns a *NotFoundError when no Passkey entities are found. +func (pq *PasskeyQuery) Only(ctx context.Context) (*Passkey, error) { + nodes, err := pq.Limit(2).All(setContextOp(ctx, pq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{passkey.Label} + default: + return nil, &NotSingularError{passkey.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (pq *PasskeyQuery) OnlyX(ctx context.Context) *Passkey { + node, err := pq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Passkey ID in the query. +// Returns a *NotSingularError when more than one Passkey ID is found. +// Returns a *NotFoundError when no entities are found. +func (pq *PasskeyQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = pq.Limit(2).IDs(setContextOp(ctx, pq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{passkey.Label} + default: + err = &NotSingularError{passkey.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. 
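// A lookup sketch in the shape of a WebAuthn login check (illustrative; how
// the surrounding application actually wires this is not shown here): fetch a
// single non-deleted passkey by credential ID and eager-load its owner in the
// same query:
//
//	pk, err := client.Passkey.Query().
//		Where(passkey.CredentialID(cid), passkey.DeletedAtIsNil()).
//		WithUser().
//		Only(ctx)
//	if ent.IsNotFound(err) {
//		// unknown credential
//	}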
+func (pq *PasskeyQuery) OnlyIDX(ctx context.Context) int { + id, err := pq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Passkeys. +func (pq *PasskeyQuery) All(ctx context.Context) ([]*Passkey, error) { + ctx = setContextOp(ctx, pq.ctx, "All") + if err := pq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Passkey, *PasskeyQuery]() + return withInterceptors[[]*Passkey](ctx, pq, qr, pq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (pq *PasskeyQuery) AllX(ctx context.Context) []*Passkey { + nodes, err := pq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Passkey IDs. +func (pq *PasskeyQuery) IDs(ctx context.Context) (ids []int, err error) { + if pq.ctx.Unique == nil && pq.path != nil { + pq.Unique(true) + } + ctx = setContextOp(ctx, pq.ctx, "IDs") + if err = pq.Select(passkey.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (pq *PasskeyQuery) IDsX(ctx context.Context) []int { + ids, err := pq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (pq *PasskeyQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, pq.ctx, "Count") + if err := pq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, pq, querierCount[*PasskeyQuery](), pq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (pq *PasskeyQuery) CountX(ctx context.Context) int { + count, err := pq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (pq *PasskeyQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, pq.ctx, "Exist") + switch _, err := pq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (pq *PasskeyQuery) ExistX(ctx context.Context) bool { + exist, err := pq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PasskeyQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (pq *PasskeyQuery) Clone() *PasskeyQuery { + if pq == nil { + return nil + } + return &PasskeyQuery{ + config: pq.config, + ctx: pq.ctx.Clone(), + order: append([]passkey.OrderOption{}, pq.order...), + inters: append([]Interceptor{}, pq.inters...), + predicates: append([]predicate.Passkey{}, pq.predicates...), + withUser: pq.withUser.Clone(), + // clone intermediate query. + sql: pq.sql.Clone(), + path: pq.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (pq *PasskeyQuery) WithUser(opts ...func(*UserQuery)) *PasskeyQuery { + query := (&UserClient{config: pq.config}).Query() + for _, opt := range opts { + opt(query) + } + pq.withUser = query + return pq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. 
+// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Passkey.Query(). +// GroupBy(passkey.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (pq *PasskeyQuery) GroupBy(field string, fields ...string) *PasskeyGroupBy { + pq.ctx.Fields = append([]string{field}, fields...) + grbuild := &PasskeyGroupBy{build: pq} + grbuild.flds = &pq.ctx.Fields + grbuild.label = passkey.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Passkey.Query(). +// Select(passkey.FieldCreatedAt). +// Scan(ctx, &v) +func (pq *PasskeyQuery) Select(fields ...string) *PasskeySelect { + pq.ctx.Fields = append(pq.ctx.Fields, fields...) + sbuild := &PasskeySelect{PasskeyQuery: pq} + sbuild.label = passkey.Label + sbuild.flds, sbuild.scan = &pq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PasskeySelect configured with the given aggregations. +func (pq *PasskeyQuery) Aggregate(fns ...AggregateFunc) *PasskeySelect { + return pq.Select().Aggregate(fns...) +} + +func (pq *PasskeyQuery) prepareQuery(ctx context.Context) error { + for _, inter := range pq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, pq); err != nil { + return err + } + } + } + for _, f := range pq.ctx.Fields { + if !passkey.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if pq.path != nil { + prev, err := pq.path(ctx) + if err != nil { + return err + } + pq.sql = prev + } + return nil +} + +func (pq *PasskeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Passkey, error) { + var ( + nodes = []*Passkey{} + _spec = pq.querySpec() + loadedTypes = [1]bool{ + pq.withUser != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Passkey).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Passkey{config: pq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := pq.withUser; query != nil { + if err := pq.loadUser(ctx, query, nodes, nil, + func(n *Passkey, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (pq *PasskeyQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Passkey, init func(*Passkey), assign func(*Passkey, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Passkey) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + 
for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (pq *PasskeyQuery) sqlCount(ctx context.Context) (int, error) { + _spec := pq.querySpec() + _spec.Node.Columns = pq.ctx.Fields + if len(pq.ctx.Fields) > 0 { + _spec.Unique = pq.ctx.Unique != nil && *pq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, pq.driver, _spec) +} + +func (pq *PasskeyQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(passkey.Table, passkey.Columns, sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt)) + _spec.From = pq.sql + if unique := pq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if pq.path != nil { + _spec.Unique = true + } + if fields := pq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, passkey.FieldID) + for i := range fields { + if fields[i] != passkey.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if pq.withUser != nil { + _spec.Node.AddColumnOnce(passkey.FieldUserID) + } + } + if ps := pq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := pq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := pq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := pq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (pq *PasskeyQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(pq.driver.Dialect()) + t1 := builder.Table(passkey.Table) + columns := pq.ctx.Fields + if len(columns) == 0 { + columns = passkey.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if pq.sql != nil { + selector = pq.sql + selector.Select(selector.Columns(columns...)...) + } + if pq.ctx.Unique != nil && *pq.ctx.Unique { + selector.Distinct() + } + for _, p := range pq.predicates { + p(selector) + } + for _, p := range pq.order { + p(selector) + } + if offset := pq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := pq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// PasskeyGroupBy is the group-by builder for Passkey entities. +type PasskeyGroupBy struct { + selector + build *PasskeyQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (pgb *PasskeyGroupBy) Aggregate(fns ...AggregateFunc) *PasskeyGroupBy { + pgb.fns = append(pgb.fns, fns...) + return pgb +} + +// Scan applies the selector query and scans the result into the given value. 
+func (pgb *PasskeyGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, pgb.build.ctx, "GroupBy") + if err := pgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PasskeyQuery, *PasskeyGroupBy](ctx, pgb.build, pgb, pgb.build.inters, v) +} + +func (pgb *PasskeyGroupBy) sqlScan(ctx context.Context, root *PasskeyQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(pgb.fns)) + for _, fn := range pgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*pgb.flds)+len(pgb.fns)) + for _, f := range *pgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*pgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := pgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PasskeySelect is the builder for selecting fields of Passkey entities. +type PasskeySelect struct { + *PasskeyQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ps *PasskeySelect) Aggregate(fns ...AggregateFunc) *PasskeySelect { + ps.fns = append(ps.fns, fns...) + return ps +} + +// Scan applies the selector query and scans the result into the given value. +func (ps *PasskeySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ps.ctx, "Select") + if err := ps.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PasskeyQuery, *PasskeySelect](ctx, ps.PasskeyQuery, ps, ps.inters, v) +} + +func (ps *PasskeySelect) sqlScan(ctx context.Context, root *PasskeyQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ps.fns)) + for _, fn := range ps.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ps.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ps.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/passkey_update.go b/ent/passkey_update.go new file mode 100644 index 00000000..8da69ac1 --- /dev/null +++ b/ent/passkey_update.go @@ -0,0 +1,546 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/go-webauthn/webauthn/webauthn" +) + +// PasskeyUpdate is the builder for updating Passkey entities. +type PasskeyUpdate struct { + config + hooks []Hook + mutation *PasskeyMutation +} + +// Where appends a list predicates to the PasskeyUpdate builder. +func (pu *PasskeyUpdate) Where(ps ...predicate.Passkey) *PasskeyUpdate { + pu.mutation.Where(ps...) + return pu +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (pu *PasskeyUpdate) SetUpdatedAt(t time.Time) *PasskeyUpdate { + pu.mutation.SetUpdatedAt(t) + return pu +} + +// SetDeletedAt sets the "deleted_at" field. +func (pu *PasskeyUpdate) SetDeletedAt(t time.Time) *PasskeyUpdate { + pu.mutation.SetDeletedAt(t) + return pu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (pu *PasskeyUpdate) SetNillableDeletedAt(t *time.Time) *PasskeyUpdate { + if t != nil { + pu.SetDeletedAt(*t) + } + return pu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (pu *PasskeyUpdate) ClearDeletedAt() *PasskeyUpdate { + pu.mutation.ClearDeletedAt() + return pu +} + +// SetUserID sets the "user_id" field. +func (pu *PasskeyUpdate) SetUserID(i int) *PasskeyUpdate { + pu.mutation.SetUserID(i) + return pu +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (pu *PasskeyUpdate) SetNillableUserID(i *int) *PasskeyUpdate { + if i != nil { + pu.SetUserID(*i) + } + return pu +} + +// SetCredentialID sets the "credential_id" field. +func (pu *PasskeyUpdate) SetCredentialID(s string) *PasskeyUpdate { + pu.mutation.SetCredentialID(s) + return pu +} + +// SetNillableCredentialID sets the "credential_id" field if the given value is not nil. +func (pu *PasskeyUpdate) SetNillableCredentialID(s *string) *PasskeyUpdate { + if s != nil { + pu.SetCredentialID(*s) + } + return pu +} + +// SetName sets the "name" field. +func (pu *PasskeyUpdate) SetName(s string) *PasskeyUpdate { + pu.mutation.SetName(s) + return pu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (pu *PasskeyUpdate) SetNillableName(s *string) *PasskeyUpdate { + if s != nil { + pu.SetName(*s) + } + return pu +} + +// SetCredential sets the "credential" field. +func (pu *PasskeyUpdate) SetCredential(w *webauthn.Credential) *PasskeyUpdate { + pu.mutation.SetCredential(w) + return pu +} + +// SetUsedAt sets the "used_at" field. +func (pu *PasskeyUpdate) SetUsedAt(t time.Time) *PasskeyUpdate { + pu.mutation.SetUsedAt(t) + return pu +} + +// SetNillableUsedAt sets the "used_at" field if the given value is not nil. +func (pu *PasskeyUpdate) SetNillableUsedAt(t *time.Time) *PasskeyUpdate { + if t != nil { + pu.SetUsedAt(*t) + } + return pu +} + +// ClearUsedAt clears the value of the "used_at" field. +func (pu *PasskeyUpdate) ClearUsedAt() *PasskeyUpdate { + pu.mutation.ClearUsedAt() + return pu +} + +// SetUser sets the "user" edge to the User entity. +func (pu *PasskeyUpdate) SetUser(u *User) *PasskeyUpdate { + return pu.SetUserID(u.ID) +} + +// Mutation returns the PasskeyMutation object of the builder. +func (pu *PasskeyUpdate) Mutation() *PasskeyMutation { + return pu.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (pu *PasskeyUpdate) ClearUser() *PasskeyUpdate { + pu.mutation.ClearUser() + return pu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (pu *PasskeyUpdate) Save(ctx context.Context) (int, error) { + if err := pu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, pu.sqlSave, pu.mutation, pu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (pu *PasskeyUpdate) SaveX(ctx context.Context) int { + affected, err := pu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (pu *PasskeyUpdate) Exec(ctx context.Context) error { + _, err := pu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (pu *PasskeyUpdate) ExecX(ctx context.Context) { + if err := pu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (pu *PasskeyUpdate) defaults() error { + if _, ok := pu.mutation.UpdatedAt(); !ok { + if passkey.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized passkey.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := passkey.UpdateDefaultUpdatedAt() + pu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (pu *PasskeyUpdate) check() error { + if _, ok := pu.mutation.UserID(); pu.mutation.UserCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Passkey.user"`) + } + return nil +} + +func (pu *PasskeyUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := pu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(passkey.Table, passkey.Columns, sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt)) + if ps := pu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := pu.mutation.UpdatedAt(); ok { + _spec.SetField(passkey.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := pu.mutation.DeletedAt(); ok { + _spec.SetField(passkey.FieldDeletedAt, field.TypeTime, value) + } + if pu.mutation.DeletedAtCleared() { + _spec.ClearField(passkey.FieldDeletedAt, field.TypeTime) + } + if value, ok := pu.mutation.CredentialID(); ok { + _spec.SetField(passkey.FieldCredentialID, field.TypeString, value) + } + if value, ok := pu.mutation.Name(); ok { + _spec.SetField(passkey.FieldName, field.TypeString, value) + } + if value, ok := pu.mutation.Credential(); ok { + _spec.SetField(passkey.FieldCredential, field.TypeJSON, value) + } + if value, ok := pu.mutation.UsedAt(); ok { + _spec.SetField(passkey.FieldUsedAt, field.TypeTime, value) + } + if pu.mutation.UsedAtCleared() { + _spec.ClearField(passkey.FieldUsedAt, field.TypeTime) + } + if pu.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: passkey.UserTable, + Columns: []string{passkey.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := pu.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: passkey.UserTable, + Columns: []string{passkey.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, pu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{passkey.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + pu.mutation.done = true + return n, nil +} + +// PasskeyUpdateOne is the builder for updating a single Passkey entity. 
+type PasskeyUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PasskeyMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (puo *PasskeyUpdateOne) SetUpdatedAt(t time.Time) *PasskeyUpdateOne { + puo.mutation.SetUpdatedAt(t) + return puo +} + +// SetDeletedAt sets the "deleted_at" field. +func (puo *PasskeyUpdateOne) SetDeletedAt(t time.Time) *PasskeyUpdateOne { + puo.mutation.SetDeletedAt(t) + return puo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (puo *PasskeyUpdateOne) SetNillableDeletedAt(t *time.Time) *PasskeyUpdateOne { + if t != nil { + puo.SetDeletedAt(*t) + } + return puo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (puo *PasskeyUpdateOne) ClearDeletedAt() *PasskeyUpdateOne { + puo.mutation.ClearDeletedAt() + return puo +} + +// SetUserID sets the "user_id" field. +func (puo *PasskeyUpdateOne) SetUserID(i int) *PasskeyUpdateOne { + puo.mutation.SetUserID(i) + return puo +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (puo *PasskeyUpdateOne) SetNillableUserID(i *int) *PasskeyUpdateOne { + if i != nil { + puo.SetUserID(*i) + } + return puo +} + +// SetCredentialID sets the "credential_id" field. +func (puo *PasskeyUpdateOne) SetCredentialID(s string) *PasskeyUpdateOne { + puo.mutation.SetCredentialID(s) + return puo +} + +// SetNillableCredentialID sets the "credential_id" field if the given value is not nil. +func (puo *PasskeyUpdateOne) SetNillableCredentialID(s *string) *PasskeyUpdateOne { + if s != nil { + puo.SetCredentialID(*s) + } + return puo +} + +// SetName sets the "name" field. +func (puo *PasskeyUpdateOne) SetName(s string) *PasskeyUpdateOne { + puo.mutation.SetName(s) + return puo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (puo *PasskeyUpdateOne) SetNillableName(s *string) *PasskeyUpdateOne { + if s != nil { + puo.SetName(*s) + } + return puo +} + +// SetCredential sets the "credential" field. +func (puo *PasskeyUpdateOne) SetCredential(w *webauthn.Credential) *PasskeyUpdateOne { + puo.mutation.SetCredential(w) + return puo +} + +// SetUsedAt sets the "used_at" field. +func (puo *PasskeyUpdateOne) SetUsedAt(t time.Time) *PasskeyUpdateOne { + puo.mutation.SetUsedAt(t) + return puo +} + +// SetNillableUsedAt sets the "used_at" field if the given value is not nil. +func (puo *PasskeyUpdateOne) SetNillableUsedAt(t *time.Time) *PasskeyUpdateOne { + if t != nil { + puo.SetUsedAt(*t) + } + return puo +} + +// ClearUsedAt clears the value of the "used_at" field. +func (puo *PasskeyUpdateOne) ClearUsedAt() *PasskeyUpdateOne { + puo.mutation.ClearUsedAt() + return puo +} + +// SetUser sets the "user" edge to the User entity. +func (puo *PasskeyUpdateOne) SetUser(u *User) *PasskeyUpdateOne { + return puo.SetUserID(u.ID) +} + +// Mutation returns the PasskeyMutation object of the builder. +func (puo *PasskeyUpdateOne) Mutation() *PasskeyMutation { + return puo.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (puo *PasskeyUpdateOne) ClearUser() *PasskeyUpdateOne { + puo.mutation.ClearUser() + return puo +} + +// Where appends a list predicates to the PasskeyUpdate builder. +func (puo *PasskeyUpdateOne) Where(ps ...predicate.Passkey) *PasskeyUpdateOne { + puo.mutation.Where(ps...) + return puo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. 
+func (puo *PasskeyUpdateOne) Select(field string, fields ...string) *PasskeyUpdateOne { + puo.fields = append([]string{field}, fields...) + return puo +} + +// Save executes the query and returns the updated Passkey entity. +func (puo *PasskeyUpdateOne) Save(ctx context.Context) (*Passkey, error) { + if err := puo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, puo.sqlSave, puo.mutation, puo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (puo *PasskeyUpdateOne) SaveX(ctx context.Context) *Passkey { + node, err := puo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (puo *PasskeyUpdateOne) Exec(ctx context.Context) error { + _, err := puo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (puo *PasskeyUpdateOne) ExecX(ctx context.Context) { + if err := puo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (puo *PasskeyUpdateOne) defaults() error { + if _, ok := puo.mutation.UpdatedAt(); !ok { + if passkey.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized passkey.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := passkey.UpdateDefaultUpdatedAt() + puo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (puo *PasskeyUpdateOne) check() error { + if _, ok := puo.mutation.UserID(); puo.mutation.UserCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "Passkey.user"`) + } + return nil +} + +func (puo *PasskeyUpdateOne) sqlSave(ctx context.Context) (_node *Passkey, err error) { + if err := puo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(passkey.Table, passkey.Columns, sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt)) + id, ok := puo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Passkey.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := puo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, passkey.FieldID) + for _, f := range fields { + if !passkey.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != passkey.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := puo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := puo.mutation.UpdatedAt(); ok { + _spec.SetField(passkey.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := puo.mutation.DeletedAt(); ok { + _spec.SetField(passkey.FieldDeletedAt, field.TypeTime, value) + } + if puo.mutation.DeletedAtCleared() { + _spec.ClearField(passkey.FieldDeletedAt, field.TypeTime) + } + if value, ok := puo.mutation.CredentialID(); ok { + _spec.SetField(passkey.FieldCredentialID, field.TypeString, value) + } + if value, ok := puo.mutation.Name(); ok { + _spec.SetField(passkey.FieldName, field.TypeString, value) + } + if value, ok := puo.mutation.Credential(); ok { + _spec.SetField(passkey.FieldCredential, field.TypeJSON, value) + } + if value, ok := puo.mutation.UsedAt(); ok { + _spec.SetField(passkey.FieldUsedAt, field.TypeTime, value) + } + if puo.mutation.UsedAtCleared() { + _spec.ClearField(passkey.FieldUsedAt, 
field.TypeTime) + } + if puo.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: passkey.UserTable, + Columns: []string{passkey.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := puo.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: passkey.UserTable, + Columns: []string{passkey.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Passkey{config: puo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, puo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{passkey.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + puo.mutation.done = true + return _node, nil +} diff --git a/ent/predicate/predicate.go b/ent/predicate/predicate.go new file mode 100644 index 00000000..3ee71805 --- /dev/null +++ b/ent/predicate/predicate.go @@ -0,0 +1,46 @@ +// Code generated by ent, DO NOT EDIT. + +package predicate + +import ( + "entgo.io/ent/dialect/sql" +) + +// DavAccount is the predicate function for davaccount builders. +type DavAccount func(*sql.Selector) + +// DirectLink is the predicate function for directlink builders. +type DirectLink func(*sql.Selector) + +// Entity is the predicate function for entity builders. +type Entity func(*sql.Selector) + +// File is the predicate function for file builders. +type File func(*sql.Selector) + +// Group is the predicate function for group builders. +type Group func(*sql.Selector) + +// Metadata is the predicate function for metadata builders. +type Metadata func(*sql.Selector) + +// Node is the predicate function for node builders. +type Node func(*sql.Selector) + +// Passkey is the predicate function for passkey builders. +type Passkey func(*sql.Selector) + +// Setting is the predicate function for setting builders. +type Setting func(*sql.Selector) + +// Share is the predicate function for share builders. +type Share func(*sql.Selector) + +// StoragePolicy is the predicate function for storagepolicy builders. +type StoragePolicy func(*sql.Selector) + +// Task is the predicate function for task builders. +type Task func(*sql.Selector) + +// User is the predicate function for user builders. +type User func(*sql.Selector) diff --git a/ent/runtime.go b/ent/runtime.go new file mode 100644 index 00000000..ec5f29f5 --- /dev/null +++ b/ent/runtime.go @@ -0,0 +1,5 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +// The schema-stitching logic is generated in github.com/cloudreve/Cloudreve/v4/ent/runtime/runtime.go diff --git a/ent/runtime/runtime.go b/ent/runtime/runtime.go new file mode 100644 index 00000000..cdc8fa56 --- /dev/null +++ b/ent/runtime/runtime.go @@ -0,0 +1,337 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package runtime + +import ( + "time" + + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/schema" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// The init function reads all schema descriptors with runtime code +// (default values, validators, hooks and policies) and stitches it +// to their package variables. +func init() { + davaccountMixin := schema.DavAccount{}.Mixin() + davaccountMixinHooks0 := davaccountMixin[0].Hooks() + davaccount.Hooks[0] = davaccountMixinHooks0[0] + davaccountMixinInters0 := davaccountMixin[0].Interceptors() + davaccount.Interceptors[0] = davaccountMixinInters0[0] + davaccountMixinFields0 := davaccountMixin[0].Fields() + _ = davaccountMixinFields0 + davaccountFields := schema.DavAccount{}.Fields() + _ = davaccountFields + // davaccountDescCreatedAt is the schema descriptor for created_at field. + davaccountDescCreatedAt := davaccountMixinFields0[0].Descriptor() + // davaccount.DefaultCreatedAt holds the default value on creation for the created_at field. + davaccount.DefaultCreatedAt = davaccountDescCreatedAt.Default.(func() time.Time) + // davaccountDescUpdatedAt is the schema descriptor for updated_at field. + davaccountDescUpdatedAt := davaccountMixinFields0[1].Descriptor() + // davaccount.DefaultUpdatedAt holds the default value on creation for the updated_at field. + davaccount.DefaultUpdatedAt = davaccountDescUpdatedAt.Default.(func() time.Time) + // davaccount.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + davaccount.UpdateDefaultUpdatedAt = davaccountDescUpdatedAt.UpdateDefault.(func() time.Time) + directlinkMixin := schema.DirectLink{}.Mixin() + directlinkMixinHooks0 := directlinkMixin[0].Hooks() + directlink.Hooks[0] = directlinkMixinHooks0[0] + directlinkMixinInters0 := directlinkMixin[0].Interceptors() + directlink.Interceptors[0] = directlinkMixinInters0[0] + directlinkMixinFields0 := directlinkMixin[0].Fields() + _ = directlinkMixinFields0 + directlinkFields := schema.DirectLink{}.Fields() + _ = directlinkFields + // directlinkDescCreatedAt is the schema descriptor for created_at field. + directlinkDescCreatedAt := directlinkMixinFields0[0].Descriptor() + // directlink.DefaultCreatedAt holds the default value on creation for the created_at field. + directlink.DefaultCreatedAt = directlinkDescCreatedAt.Default.(func() time.Time) + // directlinkDescUpdatedAt is the schema descriptor for updated_at field. + directlinkDescUpdatedAt := directlinkMixinFields0[1].Descriptor() + // directlink.DefaultUpdatedAt holds the default value on creation for the updated_at field. + directlink.DefaultUpdatedAt = directlinkDescUpdatedAt.Default.(func() time.Time) + // directlink.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ directlink.UpdateDefaultUpdatedAt = directlinkDescUpdatedAt.UpdateDefault.(func() time.Time) + entityMixin := schema.Entity{}.Mixin() + entityMixinHooks0 := entityMixin[0].Hooks() + entity.Hooks[0] = entityMixinHooks0[0] + entityMixinInters0 := entityMixin[0].Interceptors() + entity.Interceptors[0] = entityMixinInters0[0] + entityMixinFields0 := entityMixin[0].Fields() + _ = entityMixinFields0 + entityFields := schema.Entity{}.Fields() + _ = entityFields + // entityDescCreatedAt is the schema descriptor for created_at field. + entityDescCreatedAt := entityMixinFields0[0].Descriptor() + // entity.DefaultCreatedAt holds the default value on creation for the created_at field. + entity.DefaultCreatedAt = entityDescCreatedAt.Default.(func() time.Time) + // entityDescUpdatedAt is the schema descriptor for updated_at field. + entityDescUpdatedAt := entityMixinFields0[1].Descriptor() + // entity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + entity.DefaultUpdatedAt = entityDescUpdatedAt.Default.(func() time.Time) + // entity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + entity.UpdateDefaultUpdatedAt = entityDescUpdatedAt.UpdateDefault.(func() time.Time) + // entityDescReferenceCount is the schema descriptor for reference_count field. + entityDescReferenceCount := entityFields[3].Descriptor() + // entity.DefaultReferenceCount holds the default value on creation for the reference_count field. + entity.DefaultReferenceCount = entityDescReferenceCount.Default.(int) + fileMixin := schema.File{}.Mixin() + fileMixinHooks0 := fileMixin[0].Hooks() + file.Hooks[0] = fileMixinHooks0[0] + fileMixinInters0 := fileMixin[0].Interceptors() + file.Interceptors[0] = fileMixinInters0[0] + fileMixinFields0 := fileMixin[0].Fields() + _ = fileMixinFields0 + fileFields := schema.File{}.Fields() + _ = fileFields + // fileDescCreatedAt is the schema descriptor for created_at field. + fileDescCreatedAt := fileMixinFields0[0].Descriptor() + // file.DefaultCreatedAt holds the default value on creation for the created_at field. + file.DefaultCreatedAt = fileDescCreatedAt.Default.(func() time.Time) + // fileDescUpdatedAt is the schema descriptor for updated_at field. + fileDescUpdatedAt := fileMixinFields0[1].Descriptor() + // file.DefaultUpdatedAt holds the default value on creation for the updated_at field. + file.DefaultUpdatedAt = fileDescUpdatedAt.Default.(func() time.Time) + // file.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + file.UpdateDefaultUpdatedAt = fileDescUpdatedAt.UpdateDefault.(func() time.Time) + // fileDescSize is the schema descriptor for size field. + fileDescSize := fileFields[3].Descriptor() + // file.DefaultSize holds the default value on creation for the size field. + file.DefaultSize = fileDescSize.Default.(int64) + // fileDescIsSymbolic is the schema descriptor for is_symbolic field. + fileDescIsSymbolic := fileFields[6].Descriptor() + // file.DefaultIsSymbolic holds the default value on creation for the is_symbolic field. 
+ file.DefaultIsSymbolic = fileDescIsSymbolic.Default.(bool) + groupMixin := schema.Group{}.Mixin() + groupMixinHooks0 := groupMixin[0].Hooks() + group.Hooks[0] = groupMixinHooks0[0] + groupMixinInters0 := groupMixin[0].Interceptors() + group.Interceptors[0] = groupMixinInters0[0] + groupMixinFields0 := groupMixin[0].Fields() + _ = groupMixinFields0 + groupFields := schema.Group{}.Fields() + _ = groupFields + // groupDescCreatedAt is the schema descriptor for created_at field. + groupDescCreatedAt := groupMixinFields0[0].Descriptor() + // group.DefaultCreatedAt holds the default value on creation for the created_at field. + group.DefaultCreatedAt = groupDescCreatedAt.Default.(func() time.Time) + // groupDescUpdatedAt is the schema descriptor for updated_at field. + groupDescUpdatedAt := groupMixinFields0[1].Descriptor() + // group.DefaultUpdatedAt holds the default value on creation for the updated_at field. + group.DefaultUpdatedAt = groupDescUpdatedAt.Default.(func() time.Time) + // group.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + group.UpdateDefaultUpdatedAt = groupDescUpdatedAt.UpdateDefault.(func() time.Time) + // groupDescSettings is the schema descriptor for settings field. + groupDescSettings := groupFields[4].Descriptor() + // group.DefaultSettings holds the default value on creation for the settings field. + group.DefaultSettings = groupDescSettings.Default.(*types.GroupSetting) + metadataMixin := schema.Metadata{}.Mixin() + metadataMixinHooks0 := metadataMixin[0].Hooks() + metadata.Hooks[0] = metadataMixinHooks0[0] + metadataMixinInters0 := metadataMixin[0].Interceptors() + metadata.Interceptors[0] = metadataMixinInters0[0] + metadataMixinFields0 := metadataMixin[0].Fields() + _ = metadataMixinFields0 + metadataFields := schema.Metadata{}.Fields() + _ = metadataFields + // metadataDescCreatedAt is the schema descriptor for created_at field. + metadataDescCreatedAt := metadataMixinFields0[0].Descriptor() + // metadata.DefaultCreatedAt holds the default value on creation for the created_at field. + metadata.DefaultCreatedAt = metadataDescCreatedAt.Default.(func() time.Time) + // metadataDescUpdatedAt is the schema descriptor for updated_at field. + metadataDescUpdatedAt := metadataMixinFields0[1].Descriptor() + // metadata.DefaultUpdatedAt holds the default value on creation for the updated_at field. + metadata.DefaultUpdatedAt = metadataDescUpdatedAt.Default.(func() time.Time) + // metadata.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + metadata.UpdateDefaultUpdatedAt = metadataDescUpdatedAt.UpdateDefault.(func() time.Time) + // metadataDescIsPublic is the schema descriptor for is_public field. + metadataDescIsPublic := metadataFields[3].Descriptor() + // metadata.DefaultIsPublic holds the default value on creation for the is_public field. + metadata.DefaultIsPublic = metadataDescIsPublic.Default.(bool) + nodeMixin := schema.Node{}.Mixin() + nodeMixinHooks0 := nodeMixin[0].Hooks() + node.Hooks[0] = nodeMixinHooks0[0] + nodeMixinInters0 := nodeMixin[0].Interceptors() + node.Interceptors[0] = nodeMixinInters0[0] + nodeMixinFields0 := nodeMixin[0].Fields() + _ = nodeMixinFields0 + nodeFields := schema.Node{}.Fields() + _ = nodeFields + // nodeDescCreatedAt is the schema descriptor for created_at field. + nodeDescCreatedAt := nodeMixinFields0[0].Descriptor() + // node.DefaultCreatedAt holds the default value on creation for the created_at field. 
+ node.DefaultCreatedAt = nodeDescCreatedAt.Default.(func() time.Time) + // nodeDescUpdatedAt is the schema descriptor for updated_at field. + nodeDescUpdatedAt := nodeMixinFields0[1].Descriptor() + // node.DefaultUpdatedAt holds the default value on creation for the updated_at field. + node.DefaultUpdatedAt = nodeDescUpdatedAt.Default.(func() time.Time) + // node.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + node.UpdateDefaultUpdatedAt = nodeDescUpdatedAt.UpdateDefault.(func() time.Time) + // nodeDescSettings is the schema descriptor for settings field. + nodeDescSettings := nodeFields[6].Descriptor() + // node.DefaultSettings holds the default value on creation for the settings field. + node.DefaultSettings = nodeDescSettings.Default.(*types.NodeSetting) + // nodeDescWeight is the schema descriptor for weight field. + nodeDescWeight := nodeFields[7].Descriptor() + // node.DefaultWeight holds the default value on creation for the weight field. + node.DefaultWeight = nodeDescWeight.Default.(int) + passkeyMixin := schema.Passkey{}.Mixin() + passkeyMixinHooks0 := passkeyMixin[0].Hooks() + passkey.Hooks[0] = passkeyMixinHooks0[0] + passkeyMixinInters0 := passkeyMixin[0].Interceptors() + passkey.Interceptors[0] = passkeyMixinInters0[0] + passkeyMixinFields0 := passkeyMixin[0].Fields() + _ = passkeyMixinFields0 + passkeyFields := schema.Passkey{}.Fields() + _ = passkeyFields + // passkeyDescCreatedAt is the schema descriptor for created_at field. + passkeyDescCreatedAt := passkeyMixinFields0[0].Descriptor() + // passkey.DefaultCreatedAt holds the default value on creation for the created_at field. + passkey.DefaultCreatedAt = passkeyDescCreatedAt.Default.(func() time.Time) + // passkeyDescUpdatedAt is the schema descriptor for updated_at field. + passkeyDescUpdatedAt := passkeyMixinFields0[1].Descriptor() + // passkey.DefaultUpdatedAt holds the default value on creation for the updated_at field. + passkey.DefaultUpdatedAt = passkeyDescUpdatedAt.Default.(func() time.Time) + // passkey.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + passkey.UpdateDefaultUpdatedAt = passkeyDescUpdatedAt.UpdateDefault.(func() time.Time) + settingMixin := schema.Setting{}.Mixin() + settingMixinHooks0 := settingMixin[0].Hooks() + setting.Hooks[0] = settingMixinHooks0[0] + settingMixinInters0 := settingMixin[0].Interceptors() + setting.Interceptors[0] = settingMixinInters0[0] + settingMixinFields0 := settingMixin[0].Fields() + _ = settingMixinFields0 + settingFields := schema.Setting{}.Fields() + _ = settingFields + // settingDescCreatedAt is the schema descriptor for created_at field. + settingDescCreatedAt := settingMixinFields0[0].Descriptor() + // setting.DefaultCreatedAt holds the default value on creation for the created_at field. + setting.DefaultCreatedAt = settingDescCreatedAt.Default.(func() time.Time) + // settingDescUpdatedAt is the schema descriptor for updated_at field. + settingDescUpdatedAt := settingMixinFields0[1].Descriptor() + // setting.DefaultUpdatedAt holds the default value on creation for the updated_at field. + setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) + // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + shareMixin := schema.Share{}.Mixin() + shareMixinHooks0 := shareMixin[0].Hooks() + share.Hooks[0] = shareMixinHooks0[0] + shareMixinInters0 := shareMixin[0].Interceptors() + share.Interceptors[0] = shareMixinInters0[0] + shareMixinFields0 := shareMixin[0].Fields() + _ = shareMixinFields0 + shareFields := schema.Share{}.Fields() + _ = shareFields + // shareDescCreatedAt is the schema descriptor for created_at field. + shareDescCreatedAt := shareMixinFields0[0].Descriptor() + // share.DefaultCreatedAt holds the default value on creation for the created_at field. + share.DefaultCreatedAt = shareDescCreatedAt.Default.(func() time.Time) + // shareDescUpdatedAt is the schema descriptor for updated_at field. + shareDescUpdatedAt := shareMixinFields0[1].Descriptor() + // share.DefaultUpdatedAt holds the default value on creation for the updated_at field. + share.DefaultUpdatedAt = shareDescUpdatedAt.Default.(func() time.Time) + // share.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + share.UpdateDefaultUpdatedAt = shareDescUpdatedAt.UpdateDefault.(func() time.Time) + // shareDescViews is the schema descriptor for views field. + shareDescViews := shareFields[1].Descriptor() + // share.DefaultViews holds the default value on creation for the views field. + share.DefaultViews = shareDescViews.Default.(int) + // shareDescDownloads is the schema descriptor for downloads field. + shareDescDownloads := shareFields[2].Descriptor() + // share.DefaultDownloads holds the default value on creation for the downloads field. + share.DefaultDownloads = shareDescDownloads.Default.(int) + storagepolicyMixin := schema.StoragePolicy{}.Mixin() + storagepolicyMixinHooks0 := storagepolicyMixin[0].Hooks() + storagepolicy.Hooks[0] = storagepolicyMixinHooks0[0] + storagepolicyMixinInters0 := storagepolicyMixin[0].Interceptors() + storagepolicy.Interceptors[0] = storagepolicyMixinInters0[0] + storagepolicyMixinFields0 := storagepolicyMixin[0].Fields() + _ = storagepolicyMixinFields0 + storagepolicyFields := schema.StoragePolicy{}.Fields() + _ = storagepolicyFields + // storagepolicyDescCreatedAt is the schema descriptor for created_at field. + storagepolicyDescCreatedAt := storagepolicyMixinFields0[0].Descriptor() + // storagepolicy.DefaultCreatedAt holds the default value on creation for the created_at field. + storagepolicy.DefaultCreatedAt = storagepolicyDescCreatedAt.Default.(func() time.Time) + // storagepolicyDescUpdatedAt is the schema descriptor for updated_at field. + storagepolicyDescUpdatedAt := storagepolicyMixinFields0[1].Descriptor() + // storagepolicy.DefaultUpdatedAt holds the default value on creation for the updated_at field. + storagepolicy.DefaultUpdatedAt = storagepolicyDescUpdatedAt.Default.(func() time.Time) + // storagepolicy.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + storagepolicy.UpdateDefaultUpdatedAt = storagepolicyDescUpdatedAt.UpdateDefault.(func() time.Time) + // storagepolicyDescSettings is the schema descriptor for settings field. + storagepolicyDescSettings := storagepolicyFields[10].Descriptor() + // storagepolicy.DefaultSettings holds the default value on creation for the settings field. 
+ storagepolicy.DefaultSettings = storagepolicyDescSettings.Default.(*types.PolicySetting) + taskMixin := schema.Task{}.Mixin() + taskMixinHooks0 := taskMixin[0].Hooks() + task.Hooks[0] = taskMixinHooks0[0] + taskMixinInters0 := taskMixin[0].Interceptors() + task.Interceptors[0] = taskMixinInters0[0] + taskMixinFields0 := taskMixin[0].Fields() + _ = taskMixinFields0 + taskFields := schema.Task{}.Fields() + _ = taskFields + // taskDescCreatedAt is the schema descriptor for created_at field. + taskDescCreatedAt := taskMixinFields0[0].Descriptor() + // task.DefaultCreatedAt holds the default value on creation for the created_at field. + task.DefaultCreatedAt = taskDescCreatedAt.Default.(func() time.Time) + // taskDescUpdatedAt is the schema descriptor for updated_at field. + taskDescUpdatedAt := taskMixinFields0[1].Descriptor() + // task.DefaultUpdatedAt holds the default value on creation for the updated_at field. + task.DefaultUpdatedAt = taskDescUpdatedAt.Default.(func() time.Time) + // task.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + task.UpdateDefaultUpdatedAt = taskDescUpdatedAt.UpdateDefault.(func() time.Time) + userMixin := schema.User{}.Mixin() + userMixinHooks0 := userMixin[0].Hooks() + user.Hooks[0] = userMixinHooks0[0] + userMixinInters0 := userMixin[0].Interceptors() + user.Interceptors[0] = userMixinInters0[0] + userMixinFields0 := userMixin[0].Fields() + _ = userMixinFields0 + userFields := schema.User{}.Fields() + _ = userFields + // userDescCreatedAt is the schema descriptor for created_at field. + userDescCreatedAt := userMixinFields0[0].Descriptor() + // user.DefaultCreatedAt holds the default value on creation for the created_at field. + user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) + // userDescUpdatedAt is the schema descriptor for updated_at field. + userDescUpdatedAt := userMixinFields0[1].Descriptor() + // user.DefaultUpdatedAt holds the default value on creation for the updated_at field. + user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time) + // user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + user.UpdateDefaultUpdatedAt = userDescUpdatedAt.UpdateDefault.(func() time.Time) + // userDescEmail is the schema descriptor for email field. + userDescEmail := userFields[0].Descriptor() + // user.EmailValidator is a validator for the "email" field. It is called by the builders before save. + user.EmailValidator = userDescEmail.Validators[0].(func(string) error) + // userDescNick is the schema descriptor for nick field. + userDescNick := userFields[1].Descriptor() + // user.NickValidator is a validator for the "nick" field. It is called by the builders before save. + user.NickValidator = userDescNick.Validators[0].(func(string) error) + // userDescStorage is the schema descriptor for storage field. + userDescStorage := userFields[4].Descriptor() + // user.DefaultStorage holds the default value on creation for the storage field. + user.DefaultStorage = userDescStorage.Default.(int64) + // userDescSettings is the schema descriptor for settings field. + userDescSettings := userFields[7].Descriptor() + // user.DefaultSettings holds the default value on creation for the settings field. + user.DefaultSettings = userDescSettings.Default.(*types.UserSetting) +} + +const ( + Version = "v0.13.0" // Version of ent codegen. + Sum = "h1:DclxWczaCpyiKn6ZWVcJjq1zIKtJ11iNKy+08lNYsJE=" // Sum of ent codegen. 
+) diff --git a/ent/schema/common.go b/ent/schema/common.go new file mode 100644 index 00000000..f78fef21 --- /dev/null +++ b/ent/schema/common.go @@ -0,0 +1,132 @@ +package schema + +import ( + "context" + "fmt" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/mixin" + gen "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/hook" + "github.com/cloudreve/Cloudreve/v4/ent/intercept" +) + +// CommonMixin implements the soft delete pattern for schemas and common audit features. +type CommonMixin struct { + mixin.Schema +} + +// Fields of the CommonMixin. +func (CommonMixin) Fields() []ent.Field { + return commonFields() +} + +type softDeleteKey struct{} + +// SkipSoftDelete returns a new context that skips the soft-delete interceptor/mutators. +func SkipSoftDelete(parent context.Context) context.Context { + return context.WithValue(parent, softDeleteKey{}, true) +} + +// Interceptors of the CommonMixin. +func (d CommonMixin) Interceptors() []ent.Interceptor { + return softDeleteInterceptors(d) +} + +// Hooks of the CommonMixin. +func (d CommonMixin) Hooks() []ent.Hook { + return commonHooks(d) +} + +// P adds a storage-level predicate to the queries and mutations. +func (d CommonMixin) P(w interface{ WhereP(...func(*sql.Selector)) }) { + p(d, w) +} + +// Indexes of the CommonMixin. +func (CommonMixin) Indexes() []ent.Index { + return []ent.Index{} +} + +func softDeleteInterceptors(d interface { + P(w interface { + WhereP(...func(*sql.Selector)) + }) +}) []ent.Interceptor { + return []ent.Interceptor{ + intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error { + // Skip soft-delete, means include soft-deleted entities. + if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip { + return nil + } + d.P(q) + return nil + }), + } +} + +func p(d interface{ Fields() []ent.Field }, w interface{ WhereP(...func(*sql.Selector)) }) { + w.WhereP( + sql.FieldIsNull(d.Fields()[2].Descriptor().Name), + ) +} + +func commonFields() []ent.Field { + return []ent.Field{ + field.Time("created_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{ + dialect.MySQL: "datetime", + }), + field.Time("updated_at"). + Default(time.Now). + UpdateDefault(time.Now). + SchemaType(map[string]string{ + dialect.MySQL: "datetime", + }), + field.Time("deleted_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{ + dialect.MySQL: "datetime", + }), + } +} + +func commonHooks(d interface { + P(w interface { + WhereP(...func(*sql.Selector)) + }) +}) []ent.Hook { + return []ent.Hook{ + hook.On( + func(next ent.Mutator) ent.Mutator { + return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { + // Skip soft-delete, means delete the entity permanently. 
+ if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip { + return next.Mutate(ctx, m) + } + mx, ok := m.(interface { + SetOp(ent.Op) + Client() *gen.Client + SetDeletedAt(time.Time) + WhereP(...func(*sql.Selector)) + }) + if !ok { + return nil, fmt.Errorf("unexpected mutation type in soft-delete %T", m) + } + d.P(mx) + mx.SetOp(ent.OpUpdate) + mx.SetDeletedAt(time.Now()) + return mx.Client().Mutate(ctx, m) + }) + }, + ent.OpDeleteOne|ent.OpDelete, + ), + } +} diff --git a/ent/schema/davaccount.go b/ent/schema/davaccount.go new file mode 100644 index 00000000..f44b8089 --- /dev/null +++ b/ent/schema/davaccount.go @@ -0,0 +1,54 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// DavAccount holds the schema definition for the DavAccount entity. +type DavAccount struct { + ent.Schema +} + +// Fields of the DavAccount. +func (DavAccount) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + field.Text("uri"), + field.String("password"). + Sensitive(), + field.Bytes("options").GoType(&boolset.BooleanSet{}), + field.JSON("props", &types.DavAccountProps{}). + Optional(), + field.Int("owner_id"), + } +} + +// Edges of the DavAccount. +func (DavAccount) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("owner", User.Type). + Ref("dav_accounts"). + Field("owner_id"). + Unique(). + Required(), + } +} + +// Indexes of the DavAccount. +func (DavAccount) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("owner_id", "password"). + Unique(), + } +} + +func (DavAccount) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/directlink.go b/ent/schema/directlink.go new file mode 100644 index 00000000..91c7ea24 --- /dev/null +++ b/ent/schema/directlink.go @@ -0,0 +1,39 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" +) + +// DirectLink holds the schema definition for the DirectLink entity. +type DirectLink struct { + ent.Schema +} + +// Fields of the DirectLink. +func (DirectLink) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + field.Int("downloads"), + field.Int("file_id"), + field.Int("speed"), + } +} + +// Edges of the DirectLink. +func (DirectLink) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("file", File.Type). + Ref("direct_links"). + Field("file_id"). + Required(). + Unique(), + } +} + +func (DirectLink) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/entity.go b/ent/schema/entity.go new file mode 100644 index 00000000..a2b39993 --- /dev/null +++ b/ent/schema/entity.go @@ -0,0 +1,54 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// Entity holds the schema definition for the Entity. +type Entity struct { + ent.Schema +} + +// Fields of the Entity. +func (Entity) Fields() []ent.Field { + return []ent.Field{ + field.Int("type"), + field.Text("source"), + field.Int64("size"), + field.Int("reference_count").Default(1), + field.Int("storage_policy_entities"), + field.Int("created_by").Optional(), + field.UUID("upload_session_id", uuid.Must(uuid.NewV4())). + Optional(). + Nillable(), + field.JSON("recycle_options", &types.EntityRecycleOption{}). 
+ Optional(), + } +} + +// Edges of the Entity. +func (Entity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("file", File.Type). + Ref("entities"), + edge.From("user", User.Type). + Field("created_by"). + Unique(). + Ref("entities"), + edge.From("storage_policy", StoragePolicy.Type). + Ref("entities"). + Field("storage_policy_entities"). + Unique(). + Required(), + } +} + +func (Entity) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/file.go b/ent/schema/file.go new file mode 100644 index 00000000..4961fe24 --- /dev/null +++ b/ent/schema/file.go @@ -0,0 +1,73 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// File holds the schema definition for the File entity. +type File struct { + ent.Schema +} + +// Fields of the File. +func (File) Fields() []ent.Field { + return []ent.Field{ + field.Int("type"), + field.String("name"), + field.Int("owner_id"), + field.Int64("size"). + Default(0), + field.Int("primary_entity"). + Optional(), + field.Int("file_children"). + Optional(), + field.Bool("is_symbolic"). + Default(false), + field.JSON("props", &types.FileProps{}).Optional(), + field.Int("storage_policy_files"). + Optional(), + } +} + +// Edges of the File. +func (File) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("owner", User.Type). + Ref("files"). + Field("owner_id"). + Unique(). + Required(), + edge.From("storage_policies", StoragePolicy.Type). + Ref("files"). + Field("storage_policy_files"). + Unique(), + edge.To("children", File.Type). + From("parent"). + Field("file_children"). + Unique(), + edge.To("metadata", Metadata.Type), + edge.To("entities", Entity.Type), + edge.To("shares", Share.Type), + edge.To("direct_links", DirectLink.Type), + } +} + +// Indexes of the File. +func (File) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("file_children", "name"). + Unique(), + index.Fields("file_children", "type", "updated_at"), + index.Fields("file_children", "type", "size"), + } +} + +func (File) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/group.go b/ent/schema/group.go new file mode 100644 index 00000000..ca1613b8 --- /dev/null +++ b/ent/schema/group.go @@ -0,0 +1,45 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// Group holds the schema definition for the Group entity. +type Group struct { + ent.Schema +} + +func (Group) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + field.Int64("max_storage"). + Optional(), + field.Int("speed_limit"). + Optional(), + field.Bytes("permissions").GoType(&boolset.BooleanSet{}), + field.JSON("settings", &types.GroupSetting{}). + Default(&types.GroupSetting{}). + Optional(), + field.Int("storage_policy_id").Optional(), + } +} + +func (Group) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} + +func (Group) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("users", User.Type), + edge.From("storage_policies", StoragePolicy.Type). + Ref("groups"). + Field("storage_policy_id"). 
+ Unique(), + } +} diff --git a/ent/schema/metadata.go b/ent/schema/metadata.go new file mode 100644 index 00000000..3b17cdbc --- /dev/null +++ b/ent/schema/metadata.go @@ -0,0 +1,48 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// Metadata holds the schema definition for the Metadata entity. +type Metadata struct { + ent.Schema +} + +// Fields of the Metadata. +func (Metadata) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + field.Text("value"), + field.Int("file_id"), + field.Bool("is_public"). + Default(false), + } +} + +// Edges of the Metadata. +func (Metadata) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("file", File.Type). + Ref("metadata"). + Field("file_id"). + Required(). + Unique(), + } +} + +func (Metadata) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("file_id", "name"). + Unique(), + } +} + +func (Metadata) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/node.go b/ent/schema/node.go new file mode 100644 index 00000000..efce23ac --- /dev/null +++ b/ent/schema/node.go @@ -0,0 +1,46 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +// Node holds the schema definition for the Node entity. +type Node struct { + ent.Schema +} + +// Fields of the Node. +func (Node) Fields() []ent.Field { + return []ent.Field{ + field.Enum("status"). + Values("active", "suspended"), + field.String("name"), + field.Enum("type"). + Values("master", "slave"), + field.String("server"). + Optional(), + field.String("slave_key").Optional(), + field.Bytes("capabilities").GoType(&boolset.BooleanSet{}), + field.JSON("settings", &types.NodeSetting{}). + Default(&types.NodeSetting{}). + Optional(), + field.Int("weight").Default(0), + } +} + +// Edges of the Node. +func (Node) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("storage_policy", StoragePolicy.Type), + } +} + +func (Node) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/passkey.go b/ent/schema/passkey.go new file mode 100644 index 00000000..e2ba32ca --- /dev/null +++ b/ent/schema/passkey.go @@ -0,0 +1,55 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/go-webauthn/webauthn/webauthn" +) + +// Passkey holds the schema definition for the Passkey entity. +type Passkey struct { + ent.Schema +} + +// Fields of the Passkey. +func (Passkey) Fields() []ent.Field { + return []ent.Field{ + field.Int("user_id"), + field.String("credential_id"), + field.String("name"), + field.JSON("credential", &webauthn.Credential{}). + Sensitive(), + field.Time("used_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{ + dialect.MySQL: "datetime", + }), + } +} + +// Edges of the Passkey. +func (Passkey) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Field("user_id"). + Ref("passkey"). + Unique(). 
+ Required(), + } +} + +func (Passkey) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} + +func (Passkey) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("user_id", "credential_id").Unique(), + } +} diff --git a/ent/schema/policy.go b/ent/schema/policy.go new file mode 100644 index 00000000..a89aabd7 --- /dev/null +++ b/ent/schema/policy.go @@ -0,0 +1,59 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// StoragePolicy holds the schema definition for the storage policy entity. +type StoragePolicy struct { + ent.Schema +} + +func (StoragePolicy) Fields() []ent.Field { + return []ent.Field{ + field.String("name"), + field.String("type"), + field.String("server"). + Optional(), + field.String("bucket_name"). + Optional(), + field.Bool("is_private"). + Optional(), + field.Text("access_key"). + Optional(), + field.Text("secret_key"). + Optional(), + field.Int64("max_size"). + Optional(), + field.String("dir_name_rule"). + Optional(), + field.String("file_name_rule"). + Optional(), + field.JSON("settings", &types.PolicySetting{}). + Default(&types.PolicySetting{}). + Optional(), + field.Int("node_id").Optional(), + } +} + +func (StoragePolicy) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} + +func (StoragePolicy) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("users", User.Type), + edge.To("groups", Group.Type), + edge.To("files", File.Type), + edge.To("entities", Entity.Type), + edge.From("node", Node.Type). + Ref("storage_policy"). + Field("node_id"). + Unique(), + } +} diff --git a/ent/schema/setting.go b/ent/schema/setting.go new file mode 100644 index 00000000..5b5f5eef --- /dev/null +++ b/ent/schema/setting.go @@ -0,0 +1,30 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" +) + +// Setting holds the schema definition for key-value setting entity. +type Setting struct { + ent.Schema +} + +func (Setting) Fields() []ent.Field { + return []ent.Field{ + field.String("name"). + Unique(), + field.Text("value"). + Optional(), + } +} + +func (Setting) Edges() []ent.Edge { + return nil +} + +func (Setting) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/share.go b/ent/schema/share.go new file mode 100644 index 00000000..b7d2ecee --- /dev/null +++ b/ent/schema/share.go @@ -0,0 +1,50 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" +) + +// Share holds the schema definition for the Share entity. +type Share struct { + ent.Schema +} + +// Fields of the Share. +func (Share) Fields() []ent.Field { + return []ent.Field{ + field.String("password"). + Optional(), + field.Int("views"). + Default(0), + field.Int("downloads"). + Default(0), + field.Time("expires"). + Nillable(). + Optional(). + SchemaType(map[string]string{ + dialect.MySQL: "datetime", + }), + field.Int("remain_downloads"). + Nillable(). + Optional(), + } +} + +// Edges of the Share. +func (Share) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("shares").Unique(), + edge.From("file", File.Type). 
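+ // Inverse of File.shares; Unique() keeps each share attached to a single file. Neither edge is marked Required, so the FK columns stay nullable.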
+ Ref("shares").Unique(), + } +} + +func (Share) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/task.go b/ent/schema/task.go new file mode 100644 index 00000000..ee82f561 --- /dev/null +++ b/ent/schema/task.go @@ -0,0 +1,46 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// Task holds the schema definition for the Task entity. +type Task struct { + ent.Schema +} + +// Fields of the Task. +func (Task) Fields() []ent.Field { + return []ent.Field{ + field.String("type"), + field.Enum("status"). + Values("queued", "processing", "suspending", "error", "canceled", "completed"). + Default("queued"), + field.JSON("public_state", &types.TaskPublicState{}), + field.Text("private_state").Optional(), + field.UUID("correlation_id", uuid.Must(uuid.NewV4())). + Optional(). + Immutable(), + field.Int("user_tasks").Optional(), + } +} + +// Edges of the Task. +func (Task) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("tasks"). + Field("user_tasks"). + Unique(), + } +} + +func (Task) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/schema/user.go b/ent/schema/user.go new file mode 100644 index 00000000..23cba9c7 --- /dev/null +++ b/ent/schema/user.go @@ -0,0 +1,62 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// User holds the schema definition for the User entity. +type User struct { + ent.Schema +} + +func (User) Fields() []ent.Field { + return []ent.Field{ + field.String("email"). + MaxLen(100). + Unique(), + field.String("nick"). + MaxLen(100), + field.String("password"). + Optional(). + Sensitive(), + field.Enum("status"). + Values("active", "inactive", "manual_banned", "sys_banned"). + Default("active"), + field.Int64("storage"). + Default(0), + field.String("two_factor_secret"). + Sensitive(). + Optional(), + field.String("avatar"). + Optional(), + field.JSON("settings", &types.UserSetting{}). + Default(&types.UserSetting{}). + Optional(), + field.Int("group_users"), + } +} + +func (User) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("group", Group.Type). + Ref("users"). + Field("group_users"). + Unique(). + Required(), + edge.To("files", File.Type), + edge.To("dav_accounts", DavAccount.Type), + edge.To("shares", Share.Type), + edge.To("passkey", Passkey.Type), + edge.To("tasks", Task.Type), + edge.To("entities", Entity.Type), + } +} + +func (User) Mixin() []ent.Mixin { + return []ent.Mixin{ + CommonMixin{}, + } +} diff --git a/ent/setting.go b/ent/setting.go new file mode 100644 index 00000000..9aee2cf9 --- /dev/null +++ b/ent/setting.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/setting" +) + +// Setting is the model entity for the Setting schema. +type Setting struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. 
+ DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Setting) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case setting.FieldID: + values[i] = new(sql.NullInt64) + case setting.FieldName, setting.FieldValue: + values[i] = new(sql.NullString) + case setting.FieldCreatedAt, setting.FieldUpdatedAt, setting.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Setting fields. +func (s *Setting) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case setting.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + s.ID = int(value.Int64) + case setting.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + s.CreatedAt = value.Time + } + case setting.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + s.UpdatedAt = value.Time + } + case setting.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + s.DeletedAt = new(time.Time) + *s.DeletedAt = value.Time + } + case setting.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + s.Name = value.String + } + case setting.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + s.Value = value.String + } + default: + s.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue returns the ent.Value that was dynamically selected and assigned to the Setting. +// This includes values selected through modifiers, order, etc. +func (s *Setting) GetValue(name string) (ent.Value, error) { + return s.selectValues.Get(name) +} + +// Update returns a builder for updating this Setting. +// Note that you need to call Setting.Unwrap() before calling this method if this Setting +// was returned from a transaction, and the transaction was committed or rolled back. +func (s *Setting) Update() *SettingUpdateOne { + return NewSettingClient(s.config).UpdateOne(s) +} + +// Unwrap unwraps the Setting entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (s *Setting) Unwrap() *Setting { + _tx, ok := s.config.driver.(*txDriver) + if !ok { + panic("ent: Setting is not a transactional entity") + } + s.config.driver = _tx.drv + return s +} + +// String implements the fmt.Stringer. 
+func (s *Setting) String() string { + var builder strings.Builder + builder.WriteString("Setting(") + builder.WriteString(fmt.Sprintf("id=%v, ", s.ID)) + builder.WriteString("created_at=") + builder.WriteString(s.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(s.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := s.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(s.Name) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(s.Value) + builder.WriteByte(')') + return builder.String() +} + +// Settings is a parsable slice of Setting. +type Settings []*Setting diff --git a/ent/setting/setting.go b/ent/setting/setting.go new file mode 100644 index 00000000..38421871 --- /dev/null +++ b/ent/setting/setting.go @@ -0,0 +1,98 @@ +// Code generated by ent, DO NOT EDIT. + +package setting + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the setting type in the database. + Label = "setting" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // Table holds the table name of the setting in the database. + Table = "settings" +) + +// Columns holds all SQL columns for setting fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldValue, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the Setting queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. 
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/ent/setting/where.go b/ent/setting/where.go new file mode 100644 index 00000000..bf227aeb --- /dev/null +++ b/ent/setting/where.go @@ -0,0 +1,365 @@ +// Code generated by ent, DO NOT EDIT. + +package setting + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldName, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. +func Value(v string) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldValue, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. 
+func CreatedAtEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. 
+func DeletedAtIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Setting { + return predicate.Setting(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Setting { + return predicate.Setting(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.Setting { + return predicate.Setting(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.Setting { + return predicate.Setting(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.Setting { + return predicate.Setting(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. 
+func NameEqualFold(v string) predicate.Setting { + return predicate.Setting(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.Setting { + return predicate.Setting(sql.FieldContainsFold(FieldName, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. +func ValueEQ(v string) predicate.Setting { + return predicate.Setting(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.Setting { + return predicate.Setting(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.Setting { + return predicate.Setting(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. +func ValueNotIn(vs ...string) predicate.Setting { + return predicate.Setting(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.Setting { + return predicate.Setting(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.Setting { + return predicate.Setting(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.Setting { + return predicate.Setting(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.Setting { + return predicate.Setting(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.Setting { + return predicate.Setting(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.Setting { + return predicate.Setting(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. +func ValueHasSuffix(v string) predicate.Setting { + return predicate.Setting(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueIsNil applies the IsNil predicate on the "value" field. +func ValueIsNil() predicate.Setting { + return predicate.Setting(sql.FieldIsNull(FieldValue)) +} + +// ValueNotNil applies the NotNil predicate on the "value" field. +func ValueNotNil() predicate.Setting { + return predicate.Setting(sql.FieldNotNull(FieldValue)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.Setting { + return predicate.Setting(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.Setting { + return predicate.Setting(sql.FieldContainsFold(FieldValue, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Setting) predicate.Setting { + return predicate.Setting(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Setting) predicate.Setting { + return predicate.Setting(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.Setting) predicate.Setting { + return predicate.Setting(sql.NotPredicates(p)) +} diff --git a/ent/setting_create.go b/ent/setting_create.go new file mode 100644 index 00000000..9dab82f2 --- /dev/null +++ b/ent/setting_create.go @@ -0,0 +1,740 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/setting" +) + +// SettingCreate is the builder for creating a Setting entity. +type SettingCreate struct { + config + mutation *SettingMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (sc *SettingCreate) SetCreatedAt(t time.Time) *SettingCreate { + sc.mutation.SetCreatedAt(t) + return sc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (sc *SettingCreate) SetNillableCreatedAt(t *time.Time) *SettingCreate { + if t != nil { + sc.SetCreatedAt(*t) + } + return sc +} + +// SetUpdatedAt sets the "updated_at" field. +func (sc *SettingCreate) SetUpdatedAt(t time.Time) *SettingCreate { + sc.mutation.SetUpdatedAt(t) + return sc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (sc *SettingCreate) SetNillableUpdatedAt(t *time.Time) *SettingCreate { + if t != nil { + sc.SetUpdatedAt(*t) + } + return sc +} + +// SetDeletedAt sets the "deleted_at" field. +func (sc *SettingCreate) SetDeletedAt(t time.Time) *SettingCreate { + sc.mutation.SetDeletedAt(t) + return sc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (sc *SettingCreate) SetNillableDeletedAt(t *time.Time) *SettingCreate { + if t != nil { + sc.SetDeletedAt(*t) + } + return sc +} + +// SetName sets the "name" field. +func (sc *SettingCreate) SetName(s string) *SettingCreate { + sc.mutation.SetName(s) + return sc +} + +// SetValue sets the "value" field. +func (sc *SettingCreate) SetValue(s string) *SettingCreate { + sc.mutation.SetValue(s) + return sc +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (sc *SettingCreate) SetNillableValue(s *string) *SettingCreate { + if s != nil { + sc.SetValue(*s) + } + return sc +} + +// Mutation returns the SettingMutation object of the builder. +func (sc *SettingCreate) Mutation() *SettingMutation { + return sc.mutation +} + +// Save creates the Setting in the database. +func (sc *SettingCreate) Save(ctx context.Context) (*Setting, error) { + if err := sc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, sc.sqlSave, sc.mutation, sc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (sc *SettingCreate) SaveX(ctx context.Context) *Setting { + v, err := sc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (sc *SettingCreate) Exec(ctx context.Context) error { + _, err := sc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (sc *SettingCreate) ExecX(ctx context.Context) { + if err := sc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (sc *SettingCreate) defaults() error { + if _, ok := sc.mutation.CreatedAt(); !ok { + if setting.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized setting.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := setting.DefaultCreatedAt() + sc.mutation.SetCreatedAt(v) + } + if _, ok := sc.mutation.UpdatedAt(); !ok { + if setting.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized setting.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := setting.DefaultUpdatedAt() + sc.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (sc *SettingCreate) check() error { + if _, ok := sc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Setting.created_at"`)} + } + if _, ok := sc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Setting.updated_at"`)} + } + if _, ok := sc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Setting.name"`)} + } + return nil +} + +func (sc *SettingCreate) sqlSave(ctx context.Context) (*Setting, error) { + if err := sc.check(); err != nil { + return nil, err + } + _node, _spec := sc.createSpec() + if err := sqlgraph.CreateNode(ctx, sc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + sc.mutation.id = &_node.ID + sc.mutation.done = true + return _node, nil +} + +func (sc *SettingCreate) createSpec() (*Setting, *sqlgraph.CreateSpec) { + var ( + _node = &Setting{config: sc.config} + _spec = sqlgraph.NewCreateSpec(setting.Table, sqlgraph.NewFieldSpec(setting.FieldID, field.TypeInt)) + ) + + if id, ok := sc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = sc.conflict + if value, ok := sc.mutation.CreatedAt(); ok { + _spec.SetField(setting.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := sc.mutation.UpdatedAt(); ok { + _spec.SetField(setting.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := sc.mutation.DeletedAt(); ok { + _spec.SetField(setting.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := sc.mutation.Name(); ok { + _spec.SetField(setting.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := sc.mutation.Value(); ok { + _spec.SetField(setting.FieldValue, field.TypeString, value) + _node.Value = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Setting.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SettingUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (sc *SettingCreate) OnConflict(opts ...sql.ConflictOption) *SettingUpsertOne { + sc.conflict = opts + return &SettingUpsertOne{ + create: sc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Setting.Create(). 
+// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (sc *SettingCreate) OnConflictColumns(columns ...string) *SettingUpsertOne { + sc.conflict = append(sc.conflict, sql.ConflictColumns(columns...)) + return &SettingUpsertOne{ + create: sc, + } +} + +type ( + // SettingUpsertOne is the builder for "upsert"-ing + // one Setting node. + SettingUpsertOne struct { + create *SettingCreate + } + + // SettingUpsert is the "OnConflict" setter. + SettingUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SettingUpsert) SetUpdatedAt(v time.Time) *SettingUpsert { + u.Set(setting.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SettingUpsert) UpdateUpdatedAt() *SettingUpsert { + u.SetExcluded(setting.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *SettingUpsert) SetDeletedAt(v time.Time) *SettingUpsert { + u.Set(setting.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *SettingUpsert) UpdateDeletedAt() *SettingUpsert { + u.SetExcluded(setting.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *SettingUpsert) ClearDeletedAt() *SettingUpsert { + u.SetNull(setting.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *SettingUpsert) SetName(v string) *SettingUpsert { + u.Set(setting.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *SettingUpsert) UpdateName() *SettingUpsert { + u.SetExcluded(setting.FieldName) + return u +} + +// SetValue sets the "value" field. +func (u *SettingUpsert) SetValue(v string) *SettingUpsert { + u.Set(setting.FieldValue, v) + return u +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SettingUpsert) UpdateValue() *SettingUpsert { + u.SetExcluded(setting.FieldValue) + return u +} + +// ClearValue clears the value of the "value" field. +func (u *SettingUpsert) ClearValue() *SettingUpsert { + u.SetNull(setting.FieldValue) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Setting.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SettingUpsertOne) UpdateNewValues() *SettingUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(setting.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Setting.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SettingUpsertOne) Ignore() *SettingUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SettingUpsertOne) DoNothing() *SettingUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SettingCreate.OnConflict +// documentation for more info. 
+func (u *SettingUpsertOne) Update(set func(*SettingUpsert)) *SettingUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SettingUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SettingUpsertOne) SetUpdatedAt(v time.Time) *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SettingUpsertOne) UpdateUpdatedAt() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *SettingUpsertOne) SetDeletedAt(v time.Time) *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *SettingUpsertOne) UpdateDeletedAt() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *SettingUpsertOne) ClearDeletedAt() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *SettingUpsertOne) SetName(v string) *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *SettingUpsertOne) UpdateName() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.UpdateName() + }) +} + +// SetValue sets the "value" field. +func (u *SettingUpsertOne) SetValue(v string) *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SettingUpsertOne) UpdateValue() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.UpdateValue() + }) +} + +// ClearValue clears the value of the "value" field. +func (u *SettingUpsertOne) ClearValue() *SettingUpsertOne { + return u.Update(func(s *SettingUpsert) { + s.ClearValue() + }) +} + +// Exec executes the query. +func (u *SettingUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SettingCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SettingUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SettingUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SettingUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *SettingCreate) SetRawID(t int) *SettingCreate { + m.mutation.SetRawID(t) + return m +} + +// SettingCreateBulk is the builder for creating many Setting entities in bulk. +type SettingCreateBulk struct { + config + err error + builders []*SettingCreate + conflict []sql.ConflictOption +} + +// Save creates the Setting entities in the database. 
+func (scb *SettingCreateBulk) Save(ctx context.Context) ([]*Setting, error) { + if scb.err != nil { + return nil, scb.err + } + specs := make([]*sqlgraph.CreateSpec, len(scb.builders)) + nodes := make([]*Setting, len(scb.builders)) + mutators := make([]Mutator, len(scb.builders)) + for i := range scb.builders { + func(i int, root context.Context) { + builder := scb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SettingMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, scb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = scb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, scb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, scb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (scb *SettingCreateBulk) SaveX(ctx context.Context) []*Setting { + v, err := scb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (scb *SettingCreateBulk) Exec(ctx context.Context) error { + _, err := scb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (scb *SettingCreateBulk) ExecX(ctx context.Context) { + if err := scb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Setting.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SettingUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (scb *SettingCreateBulk) OnConflict(opts ...sql.ConflictOption) *SettingUpsertBulk { + scb.conflict = opts + return &SettingUpsertBulk{ + create: scb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Setting.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (scb *SettingCreateBulk) OnConflictColumns(columns ...string) *SettingUpsertBulk { + scb.conflict = append(scb.conflict, sql.ConflictColumns(columns...)) + return &SettingUpsertBulk{ + create: scb, + } +} + +// SettingUpsertBulk is the builder for "upsert"-ing +// a bulk of Setting nodes. +type SettingUpsertBulk struct { + create *SettingCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. 
Using this option is equivalent to using: +// +// client.Setting.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SettingUpsertBulk) UpdateNewValues() *SettingUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(setting.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Setting.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SettingUpsertBulk) Ignore() *SettingUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SettingUpsertBulk) DoNothing() *SettingUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SettingCreateBulk.OnConflict +// documentation for more info. +func (u *SettingUpsertBulk) Update(set func(*SettingUpsert)) *SettingUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SettingUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SettingUpsertBulk) SetUpdatedAt(v time.Time) *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SettingUpsertBulk) UpdateUpdatedAt() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *SettingUpsertBulk) SetDeletedAt(v time.Time) *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *SettingUpsertBulk) UpdateDeletedAt() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *SettingUpsertBulk) ClearDeletedAt() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *SettingUpsertBulk) SetName(v string) *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *SettingUpsertBulk) UpdateName() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.UpdateName() + }) +} + +// SetValue sets the "value" field. +func (u *SettingUpsertBulk) SetValue(v string) *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SettingUpsertBulk) UpdateValue() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.UpdateValue() + }) +} + +// ClearValue clears the value of the "value" field. +func (u *SettingUpsertBulk) ClearValue() *SettingUpsertBulk { + return u.Update(func(s *SettingUpsert) { + s.ClearValue() + }) +} + +// Exec executes the query. 
+func (u *SettingUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SettingCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SettingCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SettingUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/setting_delete.go b/ent/setting_delete.go new file mode 100644 index 00000000..b3cb0aa1 --- /dev/null +++ b/ent/setting_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/setting" +) + +// SettingDelete is the builder for deleting a Setting entity. +type SettingDelete struct { + config + hooks []Hook + mutation *SettingMutation +} + +// Where appends a list predicates to the SettingDelete builder. +func (sd *SettingDelete) Where(ps ...predicate.Setting) *SettingDelete { + sd.mutation.Where(ps...) + return sd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (sd *SettingDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, sd.sqlExec, sd.mutation, sd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (sd *SettingDelete) ExecX(ctx context.Context) int { + n, err := sd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (sd *SettingDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(setting.Table, sqlgraph.NewFieldSpec(setting.FieldID, field.TypeInt)) + if ps := sd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, sd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + sd.mutation.done = true + return affected, err +} + +// SettingDeleteOne is the builder for deleting a single Setting entity. +type SettingDeleteOne struct { + sd *SettingDelete +} + +// Where appends a list predicates to the SettingDelete builder. +func (sdo *SettingDeleteOne) Where(ps ...predicate.Setting) *SettingDeleteOne { + sdo.sd.mutation.Where(ps...) + return sdo +} + +// Exec executes the deletion query. +func (sdo *SettingDeleteOne) Exec(ctx context.Context) error { + n, err := sdo.sd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{setting.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (sdo *SettingDeleteOne) ExecX(ctx context.Context) { + if err := sdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/setting_query.go b/ent/setting_query.go new file mode 100644 index 00000000..cf8457f3 --- /dev/null +++ b/ent/setting_query.go @@ -0,0 +1,526 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/setting" +) + +// SettingQuery is the builder for querying Setting entities. +type SettingQuery struct { + config + ctx *QueryContext + order []setting.OrderOption + inters []Interceptor + predicates []predicate.Setting + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SettingQuery builder. +func (sq *SettingQuery) Where(ps ...predicate.Setting) *SettingQuery { + sq.predicates = append(sq.predicates, ps...) + return sq +} + +// Limit the number of records to be returned by this query. +func (sq *SettingQuery) Limit(limit int) *SettingQuery { + sq.ctx.Limit = &limit + return sq +} + +// Offset to start from. +func (sq *SettingQuery) Offset(offset int) *SettingQuery { + sq.ctx.Offset = &offset + return sq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (sq *SettingQuery) Unique(unique bool) *SettingQuery { + sq.ctx.Unique = &unique + return sq +} + +// Order specifies how the records should be ordered. +func (sq *SettingQuery) Order(o ...setting.OrderOption) *SettingQuery { + sq.order = append(sq.order, o...) + return sq +} + +// First returns the first Setting entity from the query. +// Returns a *NotFoundError when no Setting was found. +func (sq *SettingQuery) First(ctx context.Context) (*Setting, error) { + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{setting.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (sq *SettingQuery) FirstX(ctx context.Context) *Setting { + node, err := sq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Setting ID from the query. +// Returns a *NotFoundError when no Setting ID was found. +func (sq *SettingQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{setting.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (sq *SettingQuery) FirstIDX(ctx context.Context) int { + id, err := sq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Setting entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Setting entity is found. +// Returns a *NotFoundError when no Setting entities are found. +func (sq *SettingQuery) Only(ctx context.Context) (*Setting, error) { + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{setting.Label} + default: + return nil, &NotSingularError{setting.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (sq *SettingQuery) OnlyX(ctx context.Context) *Setting { + node, err := sq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Setting ID in the query. +// Returns a *NotSingularError when more than one Setting ID is found. +// Returns a *NotFoundError when no entities are found. +func (sq *SettingQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{setting.Label} + default: + err = &NotSingularError{setting.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (sq *SettingQuery) OnlyIDX(ctx context.Context) int { + id, err := sq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Settings. +func (sq *SettingQuery) All(ctx context.Context) ([]*Setting, error) { + ctx = setContextOp(ctx, sq.ctx, "All") + if err := sq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Setting, *SettingQuery]() + return withInterceptors[[]*Setting](ctx, sq, qr, sq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (sq *SettingQuery) AllX(ctx context.Context) []*Setting { + nodes, err := sq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Setting IDs. +func (sq *SettingQuery) IDs(ctx context.Context) (ids []int, err error) { + if sq.ctx.Unique == nil && sq.path != nil { + sq.Unique(true) + } + ctx = setContextOp(ctx, sq.ctx, "IDs") + if err = sq.Select(setting.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (sq *SettingQuery) IDsX(ctx context.Context) []int { + ids, err := sq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (sq *SettingQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sq.ctx, "Count") + if err := sq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, sq, querierCount[*SettingQuery](), sq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (sq *SettingQuery) CountX(ctx context.Context) int { + count, err := sq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (sq *SettingQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, sq.ctx, "Exist") + switch _, err := sq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (sq *SettingQuery) ExistX(ctx context.Context) bool { + exist, err := sq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SettingQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (sq *SettingQuery) Clone() *SettingQuery { + if sq == nil { + return nil + } + return &SettingQuery{ + config: sq.config, + ctx: sq.ctx.Clone(), + order: append([]setting.OrderOption{}, sq.order...), + inters: append([]Interceptor{}, sq.inters...), + predicates: append([]predicate.Setting{}, sq.predicates...), + // clone intermediate query. + sql: sq.sql.Clone(), + path: sq.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Setting.Query(). +// GroupBy(setting.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (sq *SettingQuery) GroupBy(field string, fields ...string) *SettingGroupBy { + sq.ctx.Fields = append([]string{field}, fields...) + grbuild := &SettingGroupBy{build: sq} + grbuild.flds = &sq.ctx.Fields + grbuild.label = setting.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Setting.Query(). +// Select(setting.FieldCreatedAt). +// Scan(ctx, &v) +func (sq *SettingQuery) Select(fields ...string) *SettingSelect { + sq.ctx.Fields = append(sq.ctx.Fields, fields...) + sbuild := &SettingSelect{SettingQuery: sq} + sbuild.label = setting.Label + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SettingSelect configured with the given aggregations. +func (sq *SettingQuery) Aggregate(fns ...AggregateFunc) *SettingSelect { + return sq.Select().Aggregate(fns...) 
+} + +func (sq *SettingQuery) prepareQuery(ctx context.Context) error { + for _, inter := range sq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sq); err != nil { + return err + } + } + } + for _, f := range sq.ctx.Fields { + if !setting.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if sq.path != nil { + prev, err := sq.path(ctx) + if err != nil { + return err + } + sq.sql = prev + } + return nil +} + +func (sq *SettingQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Setting, error) { + var ( + nodes = []*Setting{} + _spec = sq.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Setting).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Setting{config: sq.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, sq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (sq *SettingQuery) sqlCount(ctx context.Context) (int, error) { + _spec := sq.querySpec() + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, sq.driver, _spec) +} + +func (sq *SettingQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(setting.Table, setting.Columns, sqlgraph.NewFieldSpec(setting.FieldID, field.TypeInt)) + _spec.From = sq.sql + if unique := sq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if sq.path != nil { + _spec.Unique = true + } + if fields := sq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, setting.FieldID) + for i := range fields { + if fields[i] != setting.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := sq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := sq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := sq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := sq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (sq *SettingQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(sq.driver.Dialect()) + t1 := builder.Table(setting.Table) + columns := sq.ctx.Fields + if len(columns) == 0 { + columns = setting.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if sq.sql != nil { + selector = sq.sql + selector.Select(selector.Columns(columns...)...) + } + if sq.ctx.Unique != nil && *sq.ctx.Unique { + selector.Distinct() + } + for _, p := range sq.predicates { + p(selector) + } + for _, p := range sq.order { + p(selector) + } + if offset := sq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := sq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// SettingGroupBy is the group-by builder for Setting entities. +type SettingGroupBy struct { + selector + build *SettingQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (sgb *SettingGroupBy) Aggregate(fns ...AggregateFunc) *SettingGroupBy { + sgb.fns = append(sgb.fns, fns...) + return sgb +} + +// Scan applies the selector query and scans the result into the given value. +func (sgb *SettingGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") + if err := sgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SettingQuery, *SettingGroupBy](ctx, sgb.build, sgb, sgb.build.inters, v) +} + +func (sgb *SettingGroupBy) sqlScan(ctx context.Context, root *SettingQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sgb.fns)) + for _, fn := range sgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sgb.flds)+len(sgb.fns)) + for _, f := range *sgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*sgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := sgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SettingSelect is the builder for selecting fields of Setting entities. +type SettingSelect struct { + *SettingQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ss *SettingSelect) Aggregate(fns ...AggregateFunc) *SettingSelect { + ss.fns = append(ss.fns, fns...) + return ss +} + +// Scan applies the selector query and scans the result into the given value. +func (ss *SettingSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ss.ctx, "Select") + if err := ss.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SettingQuery, *SettingSelect](ctx, ss.SettingQuery, ss, ss.inters, v) +} + +func (ss *SettingSelect) sqlScan(ctx context.Context, root *SettingQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ss.fns)) + for _, fn := range ss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ss.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/setting_update.go b/ent/setting_update.go new file mode 100644 index 00000000..9882b8d5 --- /dev/null +++ b/ent/setting_update.go @@ -0,0 +1,362 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/setting" +) + +// SettingUpdate is the builder for updating Setting entities. +type SettingUpdate struct { + config + hooks []Hook + mutation *SettingMutation +} + +// Where appends a list predicates to the SettingUpdate builder. +func (su *SettingUpdate) Where(ps ...predicate.Setting) *SettingUpdate { + su.mutation.Where(ps...) + return su +} + +// SetUpdatedAt sets the "updated_at" field. +func (su *SettingUpdate) SetUpdatedAt(t time.Time) *SettingUpdate { + su.mutation.SetUpdatedAt(t) + return su +} + +// SetDeletedAt sets the "deleted_at" field. +func (su *SettingUpdate) SetDeletedAt(t time.Time) *SettingUpdate { + su.mutation.SetDeletedAt(t) + return su +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (su *SettingUpdate) SetNillableDeletedAt(t *time.Time) *SettingUpdate { + if t != nil { + su.SetDeletedAt(*t) + } + return su +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (su *SettingUpdate) ClearDeletedAt() *SettingUpdate { + su.mutation.ClearDeletedAt() + return su +} + +// SetName sets the "name" field. +func (su *SettingUpdate) SetName(s string) *SettingUpdate { + su.mutation.SetName(s) + return su +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (su *SettingUpdate) SetNillableName(s *string) *SettingUpdate { + if s != nil { + su.SetName(*s) + } + return su +} + +// SetValue sets the "value" field. +func (su *SettingUpdate) SetValue(s string) *SettingUpdate { + su.mutation.SetValue(s) + return su +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (su *SettingUpdate) SetNillableValue(s *string) *SettingUpdate { + if s != nil { + su.SetValue(*s) + } + return su +} + +// ClearValue clears the value of the "value" field. +func (su *SettingUpdate) ClearValue() *SettingUpdate { + su.mutation.ClearValue() + return su +} + +// Mutation returns the SettingMutation object of the builder. +func (su *SettingUpdate) Mutation() *SettingMutation { + return su.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (su *SettingUpdate) Save(ctx context.Context) (int, error) { + if err := su.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, su.sqlSave, su.mutation, su.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (su *SettingUpdate) SaveX(ctx context.Context) int { + affected, err := su.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (su *SettingUpdate) Exec(ctx context.Context) error { + _, err := su.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (su *SettingUpdate) ExecX(ctx context.Context) { + if err := su.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (su *SettingUpdate) defaults() error { + if _, ok := su.mutation.UpdatedAt(); !ok { + if setting.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized setting.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := setting.UpdateDefaultUpdatedAt() + su.mutation.SetUpdatedAt(v) + } + return nil +} + +func (su *SettingUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(setting.Table, setting.Columns, sqlgraph.NewFieldSpec(setting.FieldID, field.TypeInt)) + if ps := su.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := su.mutation.UpdatedAt(); ok { + _spec.SetField(setting.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := su.mutation.DeletedAt(); ok { + _spec.SetField(setting.FieldDeletedAt, field.TypeTime, value) + } + if su.mutation.DeletedAtCleared() { + _spec.ClearField(setting.FieldDeletedAt, field.TypeTime) + } + if value, ok := su.mutation.Name(); ok { + _spec.SetField(setting.FieldName, field.TypeString, value) + } + if value, ok := su.mutation.Value(); ok { + _spec.SetField(setting.FieldValue, field.TypeString, value) + } + if su.mutation.ValueCleared() { + _spec.ClearField(setting.FieldValue, field.TypeString) + } + if n, err = sqlgraph.UpdateNodes(ctx, su.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{setting.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + su.mutation.done = true + return n, nil +} + +// SettingUpdateOne is the builder for updating a single Setting entity. +type SettingUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SettingMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (suo *SettingUpdateOne) SetUpdatedAt(t time.Time) *SettingUpdateOne { + suo.mutation.SetUpdatedAt(t) + return suo +} + +// SetDeletedAt sets the "deleted_at" field. +func (suo *SettingUpdateOne) SetDeletedAt(t time.Time) *SettingUpdateOne { + suo.mutation.SetDeletedAt(t) + return suo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (suo *SettingUpdateOne) SetNillableDeletedAt(t *time.Time) *SettingUpdateOne { + if t != nil { + suo.SetDeletedAt(*t) + } + return suo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (suo *SettingUpdateOne) ClearDeletedAt() *SettingUpdateOne { + suo.mutation.ClearDeletedAt() + return suo +} + +// SetName sets the "name" field. +func (suo *SettingUpdateOne) SetName(s string) *SettingUpdateOne { + suo.mutation.SetName(s) + return suo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (suo *SettingUpdateOne) SetNillableName(s *string) *SettingUpdateOne { + if s != nil { + suo.SetName(*s) + } + return suo +} + +// SetValue sets the "value" field. +func (suo *SettingUpdateOne) SetValue(s string) *SettingUpdateOne { + suo.mutation.SetValue(s) + return suo +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (suo *SettingUpdateOne) SetNillableValue(s *string) *SettingUpdateOne { + if s != nil { + suo.SetValue(*s) + } + return suo +} + +// ClearValue clears the value of the "value" field. +func (suo *SettingUpdateOne) ClearValue() *SettingUpdateOne { + suo.mutation.ClearValue() + return suo +} + +// Mutation returns the SettingMutation object of the builder. 
+func (suo *SettingUpdateOne) Mutation() *SettingMutation { + return suo.mutation +} + +// Where appends a list predicates to the SettingUpdate builder. +func (suo *SettingUpdateOne) Where(ps ...predicate.Setting) *SettingUpdateOne { + suo.mutation.Where(ps...) + return suo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (suo *SettingUpdateOne) Select(field string, fields ...string) *SettingUpdateOne { + suo.fields = append([]string{field}, fields...) + return suo +} + +// Save executes the query and returns the updated Setting entity. +func (suo *SettingUpdateOne) Save(ctx context.Context) (*Setting, error) { + if err := suo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, suo.sqlSave, suo.mutation, suo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (suo *SettingUpdateOne) SaveX(ctx context.Context) *Setting { + node, err := suo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (suo *SettingUpdateOne) Exec(ctx context.Context) error { + _, err := suo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (suo *SettingUpdateOne) ExecX(ctx context.Context) { + if err := suo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (suo *SettingUpdateOne) defaults() error { + if _, ok := suo.mutation.UpdatedAt(); !ok { + if setting.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized setting.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := setting.UpdateDefaultUpdatedAt() + suo.mutation.SetUpdatedAt(v) + } + return nil +} + +func (suo *SettingUpdateOne) sqlSave(ctx context.Context) (_node *Setting, err error) { + _spec := sqlgraph.NewUpdateSpec(setting.Table, setting.Columns, sqlgraph.NewFieldSpec(setting.FieldID, field.TypeInt)) + id, ok := suo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Setting.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := suo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, setting.FieldID) + for _, f := range fields { + if !setting.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != setting.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := suo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := suo.mutation.UpdatedAt(); ok { + _spec.SetField(setting.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := suo.mutation.DeletedAt(); ok { + _spec.SetField(setting.FieldDeletedAt, field.TypeTime, value) + } + if suo.mutation.DeletedAtCleared() { + _spec.ClearField(setting.FieldDeletedAt, field.TypeTime) + } + if value, ok := suo.mutation.Name(); ok { + _spec.SetField(setting.FieldName, field.TypeString, value) + } + if value, ok := suo.mutation.Value(); ok { + _spec.SetField(setting.FieldValue, field.TypeString, value) + } + if suo.mutation.ValueCleared() { + _spec.ClearField(setting.FieldValue, field.TypeString) + } + _node = &Setting{config: suo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = 
sqlgraph.UpdateNode(ctx, suo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{setting.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + suo.mutation.done = true + return _node, nil +} diff --git a/ent/share.go b/ent/share.go new file mode 100644 index 00000000..b998e904 --- /dev/null +++ b/ent/share.go @@ -0,0 +1,276 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// Share is the model entity for the Share schema. +type Share struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Password holds the value of the "password" field. + Password string `json:"password,omitempty"` + // Views holds the value of the "views" field. + Views int `json:"views,omitempty"` + // Downloads holds the value of the "downloads" field. + Downloads int `json:"downloads,omitempty"` + // Expires holds the value of the "expires" field. + Expires *time.Time `json:"expires,omitempty"` + // RemainDownloads holds the value of the "remain_downloads" field. + RemainDownloads *int `json:"remain_downloads,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ShareQuery when eager-loading is set. + Edges ShareEdges `json:"edges"` + file_shares *int + user_shares *int + selectValues sql.SelectValues +} + +// ShareEdges holds the relations/edges for other nodes in the graph. +type ShareEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // File holds the value of the file edge. + File *File `json:"file,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e ShareEdges) UserOrErr() (*User, error) { + if e.loadedTypes[0] { + if e.User == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.User, nil + } + return nil, &NotLoadedError{edge: "user"} +} + +// FileOrErr returns the File value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e ShareEdges) FileOrErr() (*File, error) { + if e.loadedTypes[1] { + if e.File == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: file.Label} + } + return e.File, nil + } + return nil, &NotLoadedError{edge: "file"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*Share) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case share.FieldID, share.FieldViews, share.FieldDownloads, share.FieldRemainDownloads: + values[i] = new(sql.NullInt64) + case share.FieldPassword: + values[i] = new(sql.NullString) + case share.FieldCreatedAt, share.FieldUpdatedAt, share.FieldDeletedAt, share.FieldExpires: + values[i] = new(sql.NullTime) + case share.ForeignKeys[0]: // file_shares + values[i] = new(sql.NullInt64) + case share.ForeignKeys[1]: // user_shares + values[i] = new(sql.NullInt64) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Share fields. +func (s *Share) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case share.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + s.ID = int(value.Int64) + case share.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + s.CreatedAt = value.Time + } + case share.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + s.UpdatedAt = value.Time + } + case share.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + s.DeletedAt = new(time.Time) + *s.DeletedAt = value.Time + } + case share.FieldPassword: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field password", values[i]) + } else if value.Valid { + s.Password = value.String + } + case share.FieldViews: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field views", values[i]) + } else if value.Valid { + s.Views = int(value.Int64) + } + case share.FieldDownloads: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field downloads", values[i]) + } else if value.Valid { + s.Downloads = int(value.Int64) + } + case share.FieldExpires: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires", values[i]) + } else if value.Valid { + s.Expires = new(time.Time) + *s.Expires = value.Time + } + case share.FieldRemainDownloads: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field remain_downloads", values[i]) + } else if value.Valid { + s.RemainDownloads = new(int) + *s.RemainDownloads = int(value.Int64) + } + case share.ForeignKeys[0]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field file_shares", value) + } else if value.Valid { + s.file_shares = new(int) + *s.file_shares = int(value.Int64) + } + case share.ForeignKeys[1]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field user_shares", value) + } else if value.Valid { + s.user_shares = new(int) + *s.user_shares = int(value.Int64) + } + default: + s.selectValues.Set(columns[i], 
values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Share. +// This includes values selected through modifiers, order, etc. +func (s *Share) Value(name string) (ent.Value, error) { + return s.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the Share entity. +func (s *Share) QueryUser() *UserQuery { + return NewShareClient(s.config).QueryUser(s) +} + +// QueryFile queries the "file" edge of the Share entity. +func (s *Share) QueryFile() *FileQuery { + return NewShareClient(s.config).QueryFile(s) +} + +// Update returns a builder for updating this Share. +// Note that you need to call Share.Unwrap() before calling this method if this Share +// was returned from a transaction, and the transaction was committed or rolled back. +func (s *Share) Update() *ShareUpdateOne { + return NewShareClient(s.config).UpdateOne(s) +} + +// Unwrap unwraps the Share entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (s *Share) Unwrap() *Share { + _tx, ok := s.config.driver.(*txDriver) + if !ok { + panic("ent: Share is not a transactional entity") + } + s.config.driver = _tx.drv + return s +} + +// String implements the fmt.Stringer. +func (s *Share) String() string { + var builder strings.Builder + builder.WriteString("Share(") + builder.WriteString(fmt.Sprintf("id=%v, ", s.ID)) + builder.WriteString("created_at=") + builder.WriteString(s.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(s.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := s.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("password=") + builder.WriteString(s.Password) + builder.WriteString(", ") + builder.WriteString("views=") + builder.WriteString(fmt.Sprintf("%v", s.Views)) + builder.WriteString(", ") + builder.WriteString("downloads=") + builder.WriteString(fmt.Sprintf("%v", s.Downloads)) + builder.WriteString(", ") + if v := s.Expires; v != nil { + builder.WriteString("expires=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := s.RemainDownloads; v != nil { + builder.WriteString("remain_downloads=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteByte(')') + return builder.String() +} + +// SetUser manually set the edge as loaded state. +func (e *Share) SetUser(v *User) { + e.Edges.User = v + e.Edges.loadedTypes[0] = true +} + +// SetFile manually set the edge as loaded state. +func (e *Share) SetFile(v *File) { + e.Edges.File = v + e.Edges.loadedTypes[1] = true +} + +// Shares is a parsable slice of Share. +type Shares []*Share diff --git a/ent/share/share.go b/ent/share/share.go new file mode 100644 index 00000000..8327fe6e --- /dev/null +++ b/ent/share/share.go @@ -0,0 +1,185 @@ +// Code generated by ent, DO NOT EDIT. + +package share + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the share type in the database. + Label = "share" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. 
+ FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldPassword holds the string denoting the password field in the database. + FieldPassword = "password" + // FieldViews holds the string denoting the views field in the database. + FieldViews = "views" + // FieldDownloads holds the string denoting the downloads field in the database. + FieldDownloads = "downloads" + // FieldExpires holds the string denoting the expires field in the database. + FieldExpires = "expires" + // FieldRemainDownloads holds the string denoting the remain_downloads field in the database. + FieldRemainDownloads = "remain_downloads" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeFile holds the string denoting the file edge name in mutations. + EdgeFile = "file" + // Table holds the table name of the share in the database. + Table = "shares" + // UserTable is the table that holds the user relation/edge. + UserTable = "shares" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_shares" + // FileTable is the table that holds the file relation/edge. + FileTable = "shares" + // FileInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. + FileInverseTable = "files" + // FileColumn is the table column denoting the file relation/edge. + FileColumn = "file_shares" +) + +// Columns holds all SQL columns for share fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldPassword, + FieldViews, + FieldDownloads, + FieldExpires, + FieldRemainDownloads, +} + +// ForeignKeys holds the SQL foreign-keys that are owned by the "shares" +// table and are not defined as standalone fields in the schema. +var ForeignKeys = []string{ + "file_shares", + "user_shares", +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + for i := range ForeignKeys { + if column == ForeignKeys[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultViews holds the default value on creation for the "views" field. + DefaultViews int + // DefaultDownloads holds the default value on creation for the "downloads" field. + DefaultDownloads int +) + +// OrderOption defines the ordering options for the Share queries. 
+type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByViews orders the results by the views field. +func ByViews(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldViews, opts...).ToFunc() +} + +// ByDownloads orders the results by the downloads field. +func ByDownloads(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDownloads, opts...).ToFunc() +} + +// ByExpires orders the results by the expires field. +func ByExpires(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpires, opts...).ToFunc() +} + +// ByRemainDownloads orders the results by the remain_downloads field. +func ByRemainDownloads(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRemainDownloads, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByFileField orders the results by file field. +func ByFileField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFileStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newFileStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FileInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) +} diff --git a/ent/share/where.go b/ent/share/where.go new file mode 100644 index 00000000..f9361324 --- /dev/null +++ b/ent/share/where.go @@ -0,0 +1,542 @@ +// Code generated by ent, DO NOT EDIT. + +package share + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. 
+func IDIn(ids ...int) predicate.Share { + return predicate.Share(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Share { + return predicate.Share(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Share { + return predicate.Share(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. +func Password(v string) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldPassword, v)) +} + +// Views applies equality check predicate on the "views" field. It's identical to ViewsEQ. +func Views(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldViews, v)) +} + +// Downloads applies equality check predicate on the "downloads" field. It's identical to DownloadsEQ. +func Downloads(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldDownloads, v)) +} + +// Expires applies equality check predicate on the "expires" field. It's identical to ExpiresEQ. +func Expires(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldExpires, v)) +} + +// RemainDownloads applies equality check predicate on the "remain_downloads" field. It's identical to RemainDownloadsEQ. +func RemainDownloads(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldRemainDownloads, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. 
+func CreatedAtGTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. 
+func DeletedAtIsNil() predicate.Share { + return predicate.Share(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Share { + return predicate.Share(sql.FieldNotNull(FieldDeletedAt)) +} + +// PasswordEQ applies the EQ predicate on the "password" field. +func PasswordEQ(v string) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldPassword, v)) +} + +// PasswordNEQ applies the NEQ predicate on the "password" field. +func PasswordNEQ(v string) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldPassword, v)) +} + +// PasswordIn applies the In predicate on the "password" field. +func PasswordIn(vs ...string) predicate.Share { + return predicate.Share(sql.FieldIn(FieldPassword, vs...)) +} + +// PasswordNotIn applies the NotIn predicate on the "password" field. +func PasswordNotIn(vs ...string) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldPassword, vs...)) +} + +// PasswordGT applies the GT predicate on the "password" field. +func PasswordGT(v string) predicate.Share { + return predicate.Share(sql.FieldGT(FieldPassword, v)) +} + +// PasswordGTE applies the GTE predicate on the "password" field. +func PasswordGTE(v string) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldPassword, v)) +} + +// PasswordLT applies the LT predicate on the "password" field. +func PasswordLT(v string) predicate.Share { + return predicate.Share(sql.FieldLT(FieldPassword, v)) +} + +// PasswordLTE applies the LTE predicate on the "password" field. +func PasswordLTE(v string) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldPassword, v)) +} + +// PasswordContains applies the Contains predicate on the "password" field. +func PasswordContains(v string) predicate.Share { + return predicate.Share(sql.FieldContains(FieldPassword, v)) +} + +// PasswordHasPrefix applies the HasPrefix predicate on the "password" field. +func PasswordHasPrefix(v string) predicate.Share { + return predicate.Share(sql.FieldHasPrefix(FieldPassword, v)) +} + +// PasswordHasSuffix applies the HasSuffix predicate on the "password" field. +func PasswordHasSuffix(v string) predicate.Share { + return predicate.Share(sql.FieldHasSuffix(FieldPassword, v)) +} + +// PasswordIsNil applies the IsNil predicate on the "password" field. +func PasswordIsNil() predicate.Share { + return predicate.Share(sql.FieldIsNull(FieldPassword)) +} + +// PasswordNotNil applies the NotNil predicate on the "password" field. +func PasswordNotNil() predicate.Share { + return predicate.Share(sql.FieldNotNull(FieldPassword)) +} + +// PasswordEqualFold applies the EqualFold predicate on the "password" field. +func PasswordEqualFold(v string) predicate.Share { + return predicate.Share(sql.FieldEqualFold(FieldPassword, v)) +} + +// PasswordContainsFold applies the ContainsFold predicate on the "password" field. +func PasswordContainsFold(v string) predicate.Share { + return predicate.Share(sql.FieldContainsFold(FieldPassword, v)) +} + +// ViewsEQ applies the EQ predicate on the "views" field. +func ViewsEQ(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldViews, v)) +} + +// ViewsNEQ applies the NEQ predicate on the "views" field. +func ViewsNEQ(v int) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldViews, v)) +} + +// ViewsIn applies the In predicate on the "views" field. 
+func ViewsIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldIn(FieldViews, vs...)) +} + +// ViewsNotIn applies the NotIn predicate on the "views" field. +func ViewsNotIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldViews, vs...)) +} + +// ViewsGT applies the GT predicate on the "views" field. +func ViewsGT(v int) predicate.Share { + return predicate.Share(sql.FieldGT(FieldViews, v)) +} + +// ViewsGTE applies the GTE predicate on the "views" field. +func ViewsGTE(v int) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldViews, v)) +} + +// ViewsLT applies the LT predicate on the "views" field. +func ViewsLT(v int) predicate.Share { + return predicate.Share(sql.FieldLT(FieldViews, v)) +} + +// ViewsLTE applies the LTE predicate on the "views" field. +func ViewsLTE(v int) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldViews, v)) +} + +// DownloadsEQ applies the EQ predicate on the "downloads" field. +func DownloadsEQ(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldDownloads, v)) +} + +// DownloadsNEQ applies the NEQ predicate on the "downloads" field. +func DownloadsNEQ(v int) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldDownloads, v)) +} + +// DownloadsIn applies the In predicate on the "downloads" field. +func DownloadsIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldIn(FieldDownloads, vs...)) +} + +// DownloadsNotIn applies the NotIn predicate on the "downloads" field. +func DownloadsNotIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldDownloads, vs...)) +} + +// DownloadsGT applies the GT predicate on the "downloads" field. +func DownloadsGT(v int) predicate.Share { + return predicate.Share(sql.FieldGT(FieldDownloads, v)) +} + +// DownloadsGTE applies the GTE predicate on the "downloads" field. +func DownloadsGTE(v int) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldDownloads, v)) +} + +// DownloadsLT applies the LT predicate on the "downloads" field. +func DownloadsLT(v int) predicate.Share { + return predicate.Share(sql.FieldLT(FieldDownloads, v)) +} + +// DownloadsLTE applies the LTE predicate on the "downloads" field. +func DownloadsLTE(v int) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldDownloads, v)) +} + +// ExpiresEQ applies the EQ predicate on the "expires" field. +func ExpiresEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldExpires, v)) +} + +// ExpiresNEQ applies the NEQ predicate on the "expires" field. +func ExpiresNEQ(v time.Time) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldExpires, v)) +} + +// ExpiresIn applies the In predicate on the "expires" field. +func ExpiresIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldIn(FieldExpires, vs...)) +} + +// ExpiresNotIn applies the NotIn predicate on the "expires" field. +func ExpiresNotIn(vs ...time.Time) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldExpires, vs...)) +} + +// ExpiresGT applies the GT predicate on the "expires" field. +func ExpiresGT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGT(FieldExpires, v)) +} + +// ExpiresGTE applies the GTE predicate on the "expires" field. +func ExpiresGTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldExpires, v)) +} + +// ExpiresLT applies the LT predicate on the "expires" field. 
+func ExpiresLT(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLT(FieldExpires, v)) +} + +// ExpiresLTE applies the LTE predicate on the "expires" field. +func ExpiresLTE(v time.Time) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldExpires, v)) +} + +// ExpiresIsNil applies the IsNil predicate on the "expires" field. +func ExpiresIsNil() predicate.Share { + return predicate.Share(sql.FieldIsNull(FieldExpires)) +} + +// ExpiresNotNil applies the NotNil predicate on the "expires" field. +func ExpiresNotNil() predicate.Share { + return predicate.Share(sql.FieldNotNull(FieldExpires)) +} + +// RemainDownloadsEQ applies the EQ predicate on the "remain_downloads" field. +func RemainDownloadsEQ(v int) predicate.Share { + return predicate.Share(sql.FieldEQ(FieldRemainDownloads, v)) +} + +// RemainDownloadsNEQ applies the NEQ predicate on the "remain_downloads" field. +func RemainDownloadsNEQ(v int) predicate.Share { + return predicate.Share(sql.FieldNEQ(FieldRemainDownloads, v)) +} + +// RemainDownloadsIn applies the In predicate on the "remain_downloads" field. +func RemainDownloadsIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldIn(FieldRemainDownloads, vs...)) +} + +// RemainDownloadsNotIn applies the NotIn predicate on the "remain_downloads" field. +func RemainDownloadsNotIn(vs ...int) predicate.Share { + return predicate.Share(sql.FieldNotIn(FieldRemainDownloads, vs...)) +} + +// RemainDownloadsGT applies the GT predicate on the "remain_downloads" field. +func RemainDownloadsGT(v int) predicate.Share { + return predicate.Share(sql.FieldGT(FieldRemainDownloads, v)) +} + +// RemainDownloadsGTE applies the GTE predicate on the "remain_downloads" field. +func RemainDownloadsGTE(v int) predicate.Share { + return predicate.Share(sql.FieldGTE(FieldRemainDownloads, v)) +} + +// RemainDownloadsLT applies the LT predicate on the "remain_downloads" field. +func RemainDownloadsLT(v int) predicate.Share { + return predicate.Share(sql.FieldLT(FieldRemainDownloads, v)) +} + +// RemainDownloadsLTE applies the LTE predicate on the "remain_downloads" field. +func RemainDownloadsLTE(v int) predicate.Share { + return predicate.Share(sql.FieldLTE(FieldRemainDownloads, v)) +} + +// RemainDownloadsIsNil applies the IsNil predicate on the "remain_downloads" field. +func RemainDownloadsIsNil() predicate.Share { + return predicate.Share(sql.FieldIsNull(FieldRemainDownloads)) +} + +// RemainDownloadsNotNil applies the NotNil predicate on the "remain_downloads" field. +func RemainDownloadsNotNil() predicate.Share { + return predicate.Share(sql.FieldNotNull(FieldRemainDownloads)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.Share { + return predicate.Share(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.Share { + return predicate.Share(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasFile applies the HasEdge predicate on the "file" edge. 
+func HasFile() predicate.Share { + return predicate.Share(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, FileTable, FileColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFileWith applies the HasEdge predicate on the "file" edge with a given conditions (other predicates). +func HasFileWith(preds ...predicate.File) predicate.Share { + return predicate.Share(func(s *sql.Selector) { + step := newFileStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Share) predicate.Share { + return predicate.Share(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Share) predicate.Share { + return predicate.Share(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Share) predicate.Share { + return predicate.Share(sql.NotPredicates(p)) +} diff --git a/ent/share_create.go b/ent/share_create.go new file mode 100644 index 00000000..31a3525e --- /dev/null +++ b/ent/share_create.go @@ -0,0 +1,1107 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// ShareCreate is the builder for creating a Share entity. +type ShareCreate struct { + config + mutation *ShareMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (sc *ShareCreate) SetCreatedAt(t time.Time) *ShareCreate { + sc.mutation.SetCreatedAt(t) + return sc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (sc *ShareCreate) SetNillableCreatedAt(t *time.Time) *ShareCreate { + if t != nil { + sc.SetCreatedAt(*t) + } + return sc +} + +// SetUpdatedAt sets the "updated_at" field. +func (sc *ShareCreate) SetUpdatedAt(t time.Time) *ShareCreate { + sc.mutation.SetUpdatedAt(t) + return sc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (sc *ShareCreate) SetNillableUpdatedAt(t *time.Time) *ShareCreate { + if t != nil { + sc.SetUpdatedAt(*t) + } + return sc +} + +// SetDeletedAt sets the "deleted_at" field. +func (sc *ShareCreate) SetDeletedAt(t time.Time) *ShareCreate { + sc.mutation.SetDeletedAt(t) + return sc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (sc *ShareCreate) SetNillableDeletedAt(t *time.Time) *ShareCreate { + if t != nil { + sc.SetDeletedAt(*t) + } + return sc +} + +// SetPassword sets the "password" field. +func (sc *ShareCreate) SetPassword(s string) *ShareCreate { + sc.mutation.SetPassword(s) + return sc +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (sc *ShareCreate) SetNillablePassword(s *string) *ShareCreate { + if s != nil { + sc.SetPassword(*s) + } + return sc +} + +// SetViews sets the "views" field. +func (sc *ShareCreate) SetViews(i int) *ShareCreate { + sc.mutation.SetViews(i) + return sc +} + +// SetNillableViews sets the "views" field if the given value is not nil. 
+func (sc *ShareCreate) SetNillableViews(i *int) *ShareCreate { + if i != nil { + sc.SetViews(*i) + } + return sc +} + +// SetDownloads sets the "downloads" field. +func (sc *ShareCreate) SetDownloads(i int) *ShareCreate { + sc.mutation.SetDownloads(i) + return sc +} + +// SetNillableDownloads sets the "downloads" field if the given value is not nil. +func (sc *ShareCreate) SetNillableDownloads(i *int) *ShareCreate { + if i != nil { + sc.SetDownloads(*i) + } + return sc +} + +// SetExpires sets the "expires" field. +func (sc *ShareCreate) SetExpires(t time.Time) *ShareCreate { + sc.mutation.SetExpires(t) + return sc +} + +// SetNillableExpires sets the "expires" field if the given value is not nil. +func (sc *ShareCreate) SetNillableExpires(t *time.Time) *ShareCreate { + if t != nil { + sc.SetExpires(*t) + } + return sc +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (sc *ShareCreate) SetRemainDownloads(i int) *ShareCreate { + sc.mutation.SetRemainDownloads(i) + return sc +} + +// SetNillableRemainDownloads sets the "remain_downloads" field if the given value is not nil. +func (sc *ShareCreate) SetNillableRemainDownloads(i *int) *ShareCreate { + if i != nil { + sc.SetRemainDownloads(*i) + } + return sc +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (sc *ShareCreate) SetUserID(id int) *ShareCreate { + sc.mutation.SetUserID(id) + return sc +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (sc *ShareCreate) SetNillableUserID(id *int) *ShareCreate { + if id != nil { + sc = sc.SetUserID(*id) + } + return sc +} + +// SetUser sets the "user" edge to the User entity. +func (sc *ShareCreate) SetUser(u *User) *ShareCreate { + return sc.SetUserID(u.ID) +} + +// SetFileID sets the "file" edge to the File entity by ID. +func (sc *ShareCreate) SetFileID(id int) *ShareCreate { + sc.mutation.SetFileID(id) + return sc +} + +// SetNillableFileID sets the "file" edge to the File entity by ID if the given value is not nil. +func (sc *ShareCreate) SetNillableFileID(id *int) *ShareCreate { + if id != nil { + sc = sc.SetFileID(*id) + } + return sc +} + +// SetFile sets the "file" edge to the File entity. +func (sc *ShareCreate) SetFile(f *File) *ShareCreate { + return sc.SetFileID(f.ID) +} + +// Mutation returns the ShareMutation object of the builder. +func (sc *ShareCreate) Mutation() *ShareMutation { + return sc.mutation +} + +// Save creates the Share in the database. +func (sc *ShareCreate) Save(ctx context.Context) (*Share, error) { + if err := sc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, sc.sqlSave, sc.mutation, sc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (sc *ShareCreate) SaveX(ctx context.Context) *Share { + v, err := sc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (sc *ShareCreate) Exec(ctx context.Context) error { + _, err := sc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (sc *ShareCreate) ExecX(ctx context.Context) { + if err := sc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (sc *ShareCreate) defaults() error { + if _, ok := sc.mutation.CreatedAt(); !ok { + if share.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized share.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := share.DefaultCreatedAt() + sc.mutation.SetCreatedAt(v) + } + if _, ok := sc.mutation.UpdatedAt(); !ok { + if share.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized share.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := share.DefaultUpdatedAt() + sc.mutation.SetUpdatedAt(v) + } + if _, ok := sc.mutation.Views(); !ok { + v := share.DefaultViews + sc.mutation.SetViews(v) + } + if _, ok := sc.mutation.Downloads(); !ok { + v := share.DefaultDownloads + sc.mutation.SetDownloads(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (sc *ShareCreate) check() error { + if _, ok := sc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Share.created_at"`)} + } + if _, ok := sc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Share.updated_at"`)} + } + if _, ok := sc.mutation.Views(); !ok { + return &ValidationError{Name: "views", err: errors.New(`ent: missing required field "Share.views"`)} + } + if _, ok := sc.mutation.Downloads(); !ok { + return &ValidationError{Name: "downloads", err: errors.New(`ent: missing required field "Share.downloads"`)} + } + return nil +} + +func (sc *ShareCreate) sqlSave(ctx context.Context) (*Share, error) { + if err := sc.check(); err != nil { + return nil, err + } + _node, _spec := sc.createSpec() + if err := sqlgraph.CreateNode(ctx, sc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + sc.mutation.id = &_node.ID + sc.mutation.done = true + return _node, nil +} + +func (sc *ShareCreate) createSpec() (*Share, *sqlgraph.CreateSpec) { + var ( + _node = &Share{config: sc.config} + _spec = sqlgraph.NewCreateSpec(share.Table, sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt)) + ) + + if id, ok := sc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = sc.conflict + if value, ok := sc.mutation.CreatedAt(); ok { + _spec.SetField(share.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := sc.mutation.UpdatedAt(); ok { + _spec.SetField(share.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := sc.mutation.DeletedAt(); ok { + _spec.SetField(share.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := sc.mutation.Password(); ok { + _spec.SetField(share.FieldPassword, field.TypeString, value) + _node.Password = value + } + if value, ok := sc.mutation.Views(); ok { + _spec.SetField(share.FieldViews, field.TypeInt, value) + _node.Views = value + } + if value, ok := sc.mutation.Downloads(); ok { + _spec.SetField(share.FieldDownloads, field.TypeInt, value) + _node.Downloads = value + } + if value, ok := sc.mutation.Expires(); ok { + _spec.SetField(share.FieldExpires, field.TypeTime, value) + _node.Expires = &value + } + if value, ok := sc.mutation.RemainDownloads(); ok { + _spec.SetField(share.FieldRemainDownloads, field.TypeInt, value) + _node.RemainDownloads = &value + } + if nodes := sc.mutation.UserIDs(); len(nodes) > 0 { + edge := 
&sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.UserTable, + Columns: []string{share.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.user_shares = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := sc.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.FileTable, + Columns: []string{share.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.file_shares = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Share.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ShareUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (sc *ShareCreate) OnConflict(opts ...sql.ConflictOption) *ShareUpsertOne { + sc.conflict = opts + return &ShareUpsertOne{ + create: sc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (sc *ShareCreate) OnConflictColumns(columns ...string) *ShareUpsertOne { + sc.conflict = append(sc.conflict, sql.ConflictColumns(columns...)) + return &ShareUpsertOne{ + create: sc, + } +} + +type ( + // ShareUpsertOne is the builder for "upsert"-ing + // one Share node. + ShareUpsertOne struct { + create *ShareCreate + } + + // ShareUpsert is the "OnConflict" setter. + ShareUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ShareUpsert) SetUpdatedAt(v time.Time) *ShareUpsert { + u.Set(share.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ShareUpsert) UpdateUpdatedAt() *ShareUpsert { + u.SetExcluded(share.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *ShareUpsert) SetDeletedAt(v time.Time) *ShareUpsert { + u.Set(share.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *ShareUpsert) UpdateDeletedAt() *ShareUpsert { + u.SetExcluded(share.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *ShareUpsert) ClearDeletedAt() *ShareUpsert { + u.SetNull(share.FieldDeletedAt) + return u +} + +// SetPassword sets the "password" field. +func (u *ShareUpsert) SetPassword(v string) *ShareUpsert { + u.Set(share.FieldPassword, v) + return u +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *ShareUpsert) UpdatePassword() *ShareUpsert { + u.SetExcluded(share.FieldPassword) + return u +} + +// ClearPassword clears the value of the "password" field. 
+func (u *ShareUpsert) ClearPassword() *ShareUpsert { + u.SetNull(share.FieldPassword) + return u +} + +// SetViews sets the "views" field. +func (u *ShareUpsert) SetViews(v int) *ShareUpsert { + u.Set(share.FieldViews, v) + return u +} + +// UpdateViews sets the "views" field to the value that was provided on create. +func (u *ShareUpsert) UpdateViews() *ShareUpsert { + u.SetExcluded(share.FieldViews) + return u +} + +// AddViews adds v to the "views" field. +func (u *ShareUpsert) AddViews(v int) *ShareUpsert { + u.Add(share.FieldViews, v) + return u +} + +// SetDownloads sets the "downloads" field. +func (u *ShareUpsert) SetDownloads(v int) *ShareUpsert { + u.Set(share.FieldDownloads, v) + return u +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *ShareUpsert) UpdateDownloads() *ShareUpsert { + u.SetExcluded(share.FieldDownloads) + return u +} + +// AddDownloads adds v to the "downloads" field. +func (u *ShareUpsert) AddDownloads(v int) *ShareUpsert { + u.Add(share.FieldDownloads, v) + return u +} + +// SetExpires sets the "expires" field. +func (u *ShareUpsert) SetExpires(v time.Time) *ShareUpsert { + u.Set(share.FieldExpires, v) + return u +} + +// UpdateExpires sets the "expires" field to the value that was provided on create. +func (u *ShareUpsert) UpdateExpires() *ShareUpsert { + u.SetExcluded(share.FieldExpires) + return u +} + +// ClearExpires clears the value of the "expires" field. +func (u *ShareUpsert) ClearExpires() *ShareUpsert { + u.SetNull(share.FieldExpires) + return u +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (u *ShareUpsert) SetRemainDownloads(v int) *ShareUpsert { + u.Set(share.FieldRemainDownloads, v) + return u +} + +// UpdateRemainDownloads sets the "remain_downloads" field to the value that was provided on create. +func (u *ShareUpsert) UpdateRemainDownloads() *ShareUpsert { + u.SetExcluded(share.FieldRemainDownloads) + return u +} + +// AddRemainDownloads adds v to the "remain_downloads" field. +func (u *ShareUpsert) AddRemainDownloads(v int) *ShareUpsert { + u.Add(share.FieldRemainDownloads, v) + return u +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (u *ShareUpsert) ClearRemainDownloads() *ShareUpsert { + u.SetNull(share.FieldRemainDownloads) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ShareUpsertOne) UpdateNewValues() *ShareUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(share.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ShareUpsertOne) Ignore() *ShareUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
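+// An illustrative sketch only (the conflict target below is hypothetical, since
+// this schema declares no unique columns besides the ID):
+//
+//	err := client.Share.Create().
+//		SetPassword("secret").
+//		OnConflictColumns(share.FieldID).
+//		DoNothing().
+//		Exec(ctx)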
+func (u *ShareUpsertOne) DoNothing() *ShareUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ShareCreate.OnConflict +// documentation for more info. +func (u *ShareUpsertOne) Update(set func(*ShareUpsert)) *ShareUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ShareUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ShareUpsertOne) SetUpdatedAt(v time.Time) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdateUpdatedAt() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *ShareUpsertOne) SetDeletedAt(v time.Time) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdateDeletedAt() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *ShareUpsertOne) ClearDeletedAt() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.ClearDeletedAt() + }) +} + +// SetPassword sets the "password" field. +func (u *ShareUpsertOne) SetPassword(v string) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdatePassword() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdatePassword() + }) +} + +// ClearPassword clears the value of the "password" field. +func (u *ShareUpsertOne) ClearPassword() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.ClearPassword() + }) +} + +// SetViews sets the "views" field. +func (u *ShareUpsertOne) SetViews(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetViews(v) + }) +} + +// AddViews adds v to the "views" field. +func (u *ShareUpsertOne) AddViews(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.AddViews(v) + }) +} + +// UpdateViews sets the "views" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdateViews() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateViews() + }) +} + +// SetDownloads sets the "downloads" field. +func (u *ShareUpsertOne) SetDownloads(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetDownloads(v) + }) +} + +// AddDownloads adds v to the "downloads" field. +func (u *ShareUpsertOne) AddDownloads(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.AddDownloads(v) + }) +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdateDownloads() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateDownloads() + }) +} + +// SetExpires sets the "expires" field. +func (u *ShareUpsertOne) SetExpires(v time.Time) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetExpires(v) + }) +} + +// UpdateExpires sets the "expires" field to the value that was provided on create. 
+func (u *ShareUpsertOne) UpdateExpires() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateExpires() + }) +} + +// ClearExpires clears the value of the "expires" field. +func (u *ShareUpsertOne) ClearExpires() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.ClearExpires() + }) +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (u *ShareUpsertOne) SetRemainDownloads(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.SetRemainDownloads(v) + }) +} + +// AddRemainDownloads adds v to the "remain_downloads" field. +func (u *ShareUpsertOne) AddRemainDownloads(v int) *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.AddRemainDownloads(v) + }) +} + +// UpdateRemainDownloads sets the "remain_downloads" field to the value that was provided on create. +func (u *ShareUpsertOne) UpdateRemainDownloads() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.UpdateRemainDownloads() + }) +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (u *ShareUpsertOne) ClearRemainDownloads() *ShareUpsertOne { + return u.Update(func(s *ShareUpsert) { + s.ClearRemainDownloads() + }) +} + +// Exec executes the query. +func (u *ShareUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ShareCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ShareUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ShareUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ShareUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *ShareCreate) SetRawID(t int) *ShareCreate { + m.mutation.SetRawID(t) + return m +} + +// ShareCreateBulk is the builder for creating many Share entities in bulk. +type ShareCreateBulk struct { + config + err error + builders []*ShareCreate + conflict []sql.ConflictOption +} + +// Save creates the Share entities in the database. +func (scb *ShareCreateBulk) Save(ctx context.Context) ([]*Share, error) { + if scb.err != nil { + return nil, scb.err + } + specs := make([]*sqlgraph.CreateSpec, len(scb.builders)) + nodes := make([]*Share, len(scb.builders)) + mutators := make([]Mutator, len(scb.builders)) + for i := range scb.builders { + func(i int, root context.Context) { + builder := scb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ShareMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, scb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = scb.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, scb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, scb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (scb *ShareCreateBulk) SaveX(ctx context.Context) []*Share { + v, err := scb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (scb *ShareCreateBulk) Exec(ctx context.Context) error { + _, err := scb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (scb *ShareCreateBulk) ExecX(ctx context.Context) { + if err := scb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Share.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ShareUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (scb *ShareCreateBulk) OnConflict(opts ...sql.ConflictOption) *ShareUpsertBulk { + scb.conflict = opts + return &ShareUpsertBulk{ + create: scb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (scb *ShareCreateBulk) OnConflictColumns(columns ...string) *ShareUpsertBulk { + scb.conflict = append(scb.conflict, sql.ConflictColumns(columns...)) + return &ShareUpsertBulk{ + create: scb, + } +} + +// ShareUpsertBulk is the builder for "upsert"-ing +// a bulk of Share nodes. +type ShareUpsertBulk struct { + create *ShareCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ShareUpsertBulk) UpdateNewValues() *ShareUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(share.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Share.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ShareUpsertBulk) Ignore() *ShareUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
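+// The same option applies to a whole batch. A rough sketch of a bulk insert that
+// ignores conflicting rows (builder construction elided; names are placeholders):
+//
+//	builders := []*ent.ShareCreate{ /* ... */ }
+//	err := client.Share.CreateBulk(builders...).
+//		OnConflict().
+//		DoNothing().
+//		Exec(ctx)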
+func (u *ShareUpsertBulk) DoNothing() *ShareUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ShareCreateBulk.OnConflict +// documentation for more info. +func (u *ShareUpsertBulk) Update(set func(*ShareUpsert)) *ShareUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ShareUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ShareUpsertBulk) SetUpdatedAt(v time.Time) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdateUpdatedAt() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *ShareUpsertBulk) SetDeletedAt(v time.Time) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdateDeletedAt() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *ShareUpsertBulk) ClearDeletedAt() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.ClearDeletedAt() + }) +} + +// SetPassword sets the "password" field. +func (u *ShareUpsertBulk) SetPassword(v string) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdatePassword() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdatePassword() + }) +} + +// ClearPassword clears the value of the "password" field. +func (u *ShareUpsertBulk) ClearPassword() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.ClearPassword() + }) +} + +// SetViews sets the "views" field. +func (u *ShareUpsertBulk) SetViews(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetViews(v) + }) +} + +// AddViews adds v to the "views" field. +func (u *ShareUpsertBulk) AddViews(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.AddViews(v) + }) +} + +// UpdateViews sets the "views" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdateViews() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateViews() + }) +} + +// SetDownloads sets the "downloads" field. +func (u *ShareUpsertBulk) SetDownloads(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetDownloads(v) + }) +} + +// AddDownloads adds v to the "downloads" field. +func (u *ShareUpsertBulk) AddDownloads(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.AddDownloads(v) + }) +} + +// UpdateDownloads sets the "downloads" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdateDownloads() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateDownloads() + }) +} + +// SetExpires sets the "expires" field. +func (u *ShareUpsertBulk) SetExpires(v time.Time) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetExpires(v) + }) +} + +// UpdateExpires sets the "expires" field to the value that was provided on create. 
+func (u *ShareUpsertBulk) UpdateExpires() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateExpires() + }) +} + +// ClearExpires clears the value of the "expires" field. +func (u *ShareUpsertBulk) ClearExpires() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.ClearExpires() + }) +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (u *ShareUpsertBulk) SetRemainDownloads(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.SetRemainDownloads(v) + }) +} + +// AddRemainDownloads adds v to the "remain_downloads" field. +func (u *ShareUpsertBulk) AddRemainDownloads(v int) *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.AddRemainDownloads(v) + }) +} + +// UpdateRemainDownloads sets the "remain_downloads" field to the value that was provided on create. +func (u *ShareUpsertBulk) UpdateRemainDownloads() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.UpdateRemainDownloads() + }) +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (u *ShareUpsertBulk) ClearRemainDownloads() *ShareUpsertBulk { + return u.Update(func(s *ShareUpsert) { + s.ClearRemainDownloads() + }) +} + +// Exec executes the query. +func (u *ShareUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ShareCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ShareCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ShareUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/share_delete.go b/ent/share_delete.go new file mode 100644 index 00000000..20737a83 --- /dev/null +++ b/ent/share_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" +) + +// ShareDelete is the builder for deleting a Share entity. +type ShareDelete struct { + config + hooks []Hook + mutation *ShareMutation +} + +// Where appends a list predicates to the ShareDelete builder. +func (sd *ShareDelete) Where(ps ...predicate.Share) *ShareDelete { + sd.mutation.Where(ps...) + return sd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (sd *ShareDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, sd.sqlExec, sd.mutation, sd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
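+// A usage sketch of the delete builder, e.g. pruning expired shares (assumes the
+// generated share.ExpiresLT predicate; client and ctx are placeholders):
+//
+//	n, err := client.Share.Delete().
+//		Where(share.ExpiresLT(time.Now())).
+//		Exec(ctx)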
+func (sd *ShareDelete) ExecX(ctx context.Context) int { + n, err := sd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (sd *ShareDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(share.Table, sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt)) + if ps := sd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, sd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + sd.mutation.done = true + return affected, err +} + +// ShareDeleteOne is the builder for deleting a single Share entity. +type ShareDeleteOne struct { + sd *ShareDelete +} + +// Where appends a list predicates to the ShareDelete builder. +func (sdo *ShareDeleteOne) Where(ps ...predicate.Share) *ShareDeleteOne { + sdo.sd.mutation.Where(ps...) + return sdo +} + +// Exec executes the deletion query. +func (sdo *ShareDeleteOne) Exec(ctx context.Context) error { + n, err := sdo.sd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{share.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (sdo *ShareDeleteOne) ExecX(ctx context.Context) { + if err := sdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/share_query.go b/ent/share_query.go new file mode 100644 index 00000000..b3bc6f53 --- /dev/null +++ b/ent/share_query.go @@ -0,0 +1,688 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// ShareQuery is the builder for querying Share entities. +type ShareQuery struct { + config + ctx *QueryContext + order []share.OrderOption + inters []Interceptor + predicates []predicate.Share + withUser *UserQuery + withFile *FileQuery + withFKs bool + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ShareQuery builder. +func (sq *ShareQuery) Where(ps ...predicate.Share) *ShareQuery { + sq.predicates = append(sq.predicates, ps...) + return sq +} + +// Limit the number of records to be returned by this query. +func (sq *ShareQuery) Limit(limit int) *ShareQuery { + sq.ctx.Limit = &limit + return sq +} + +// Offset to start from. +func (sq *ShareQuery) Offset(offset int) *ShareQuery { + sq.ctx.Offset = &offset + return sq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (sq *ShareQuery) Unique(unique bool) *ShareQuery { + sq.ctx.Unique = &unique + return sq +} + +// Order specifies how the records should be ordered. +func (sq *ShareQuery) Order(o ...share.OrderOption) *ShareQuery { + sq.order = append(sq.order, o...) + return sq +} + +// QueryUser chains the current query on the "user" edge. 
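+// For example, resolving the owner of a given share (shareID is a placeholder):
+//
+//	owner, err := client.Share.Query().
+//		Where(share.ID(shareID)).
+//		QueryUser().
+//		Only(ctx)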
+func (sq *ShareQuery) QueryUser() *UserQuery { + query := (&UserClient{config: sq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := sq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := sq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(share.Table, share.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, share.UserTable, share.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(sq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryFile chains the current query on the "file" edge. +func (sq *ShareQuery) QueryFile() *FileQuery { + query := (&FileClient{config: sq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := sq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := sq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(share.Table, share.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, share.FileTable, share.FileColumn), + ) + fromU = sqlgraph.SetNeighbors(sq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Share entity from the query. +// Returns a *NotFoundError when no Share was found. +func (sq *ShareQuery) First(ctx context.Context) (*Share, error) { + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{share.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (sq *ShareQuery) FirstX(ctx context.Context) *Share { + node, err := sq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Share ID from the query. +// Returns a *NotFoundError when no Share ID was found. +func (sq *ShareQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{share.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (sq *ShareQuery) FirstIDX(ctx context.Context) int { + id, err := sq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Share entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Share entity is found. +// Returns a *NotFoundError when no Share entities are found. +func (sq *ShareQuery) Only(ctx context.Context) (*Share, error) { + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{share.Label} + default: + return nil, &NotSingularError{share.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (sq *ShareQuery) OnlyX(ctx context.Context) *Share { + node, err := sq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Share ID in the query. +// Returns a *NotSingularError when more than one Share ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (sq *ShareQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{share.Label} + default: + err = &NotSingularError{share.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (sq *ShareQuery) OnlyIDX(ctx context.Context) int { + id, err := sq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Shares. +func (sq *ShareQuery) All(ctx context.Context) ([]*Share, error) { + ctx = setContextOp(ctx, sq.ctx, "All") + if err := sq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Share, *ShareQuery]() + return withInterceptors[[]*Share](ctx, sq, qr, sq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (sq *ShareQuery) AllX(ctx context.Context) []*Share { + nodes, err := sq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Share IDs. +func (sq *ShareQuery) IDs(ctx context.Context) (ids []int, err error) { + if sq.ctx.Unique == nil && sq.path != nil { + sq.Unique(true) + } + ctx = setContextOp(ctx, sq.ctx, "IDs") + if err = sq.Select(share.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (sq *ShareQuery) IDsX(ctx context.Context) []int { + ids, err := sq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (sq *ShareQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sq.ctx, "Count") + if err := sq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, sq, querierCount[*ShareQuery](), sq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (sq *ShareQuery) CountX(ctx context.Context) int { + count, err := sq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (sq *ShareQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, sq.ctx, "Exist") + switch _, err := sq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (sq *ShareQuery) ExistX(ctx context.Context) bool { + exist, err := sq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ShareQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (sq *ShareQuery) Clone() *ShareQuery { + if sq == nil { + return nil + } + return &ShareQuery{ + config: sq.config, + ctx: sq.ctx.Clone(), + order: append([]share.OrderOption{}, sq.order...), + inters: append([]Interceptor{}, sq.inters...), + predicates: append([]predicate.Share{}, sq.predicates...), + withUser: sq.withUser.Clone(), + withFile: sq.withFile.Clone(), + // clone intermediate query. + sql: sq.sql.Clone(), + path: sq.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. 
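+// For example, loading a share together with its owner and file in a single
+// query plan (client, ctx and shareID are placeholders):
+//
+//	s, err := client.Share.Query().
+//		Where(share.ID(shareID)).
+//		WithUser().
+//		WithFile().
+//		Only(ctx)
+//	// s.Edges.User / s.Edges.File are populated; they stay nil when the edge is unset.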
+func (sq *ShareQuery) WithUser(opts ...func(*UserQuery)) *ShareQuery { + query := (&UserClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + sq.withUser = query + return sq +} + +// WithFile tells the query-builder to eager-load the nodes that are connected to +// the "file" edge. The optional arguments are used to configure the query builder of the edge. +func (sq *ShareQuery) WithFile(opts ...func(*FileQuery)) *ShareQuery { + query := (&FileClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + sq.withFile = query + return sq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Share.Query(). +// GroupBy(share.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (sq *ShareQuery) GroupBy(field string, fields ...string) *ShareGroupBy { + sq.ctx.Fields = append([]string{field}, fields...) + grbuild := &ShareGroupBy{build: sq} + grbuild.flds = &sq.ctx.Fields + grbuild.label = share.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Share.Query(). +// Select(share.FieldCreatedAt). +// Scan(ctx, &v) +func (sq *ShareQuery) Select(fields ...string) *ShareSelect { + sq.ctx.Fields = append(sq.ctx.Fields, fields...) + sbuild := &ShareSelect{ShareQuery: sq} + sbuild.label = share.Label + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ShareSelect configured with the given aggregations. +func (sq *ShareQuery) Aggregate(fns ...AggregateFunc) *ShareSelect { + return sq.Select().Aggregate(fns...) +} + +func (sq *ShareQuery) prepareQuery(ctx context.Context) error { + for _, inter := range sq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sq); err != nil { + return err + } + } + } + for _, f := range sq.ctx.Fields { + if !share.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if sq.path != nil { + prev, err := sq.path(ctx) + if err != nil { + return err + } + sq.sql = prev + } + return nil +} + +func (sq *ShareQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Share, error) { + var ( + nodes = []*Share{} + withFKs = sq.withFKs + _spec = sq.querySpec() + loadedTypes = [2]bool{ + sq.withUser != nil, + sq.withFile != nil, + } + ) + if sq.withUser != nil || sq.withFile != nil { + withFKs = true + } + if withFKs { + _spec.Node.Columns = append(_spec.Node.Columns, share.ForeignKeys...) 
+ } + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Share).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Share{config: sq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, sq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := sq.withUser; query != nil { + if err := sq.loadUser(ctx, query, nodes, nil, + func(n *Share, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := sq.withFile; query != nil { + if err := sq.loadFile(ctx, query, nodes, nil, + func(n *Share, e *File) { n.Edges.File = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (sq *ShareQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Share, init func(*Share), assign func(*Share, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Share) + for i := range nodes { + if nodes[i].user_shares == nil { + continue + } + fk := *nodes[i].user_shares + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_shares" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (sq *ShareQuery) loadFile(ctx context.Context, query *FileQuery, nodes []*Share, init func(*Share), assign func(*Share, *File)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Share) + for i := range nodes { + if nodes[i].file_shares == nil { + continue + } + fk := *nodes[i].file_shares + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(file.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_shares" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (sq *ShareQuery) sqlCount(ctx context.Context) (int, error) { + _spec := sq.querySpec() + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, sq.driver, _spec) +} + +func (sq *ShareQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(share.Table, share.Columns, sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt)) + _spec.From = sq.sql + if unique := sq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if sq.path != nil { + _spec.Unique = true + } + if fields := sq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, share.FieldID) + for i := range fields { + if fields[i] != share.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := sq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := 
sq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := sq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := sq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (sq *ShareQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(sq.driver.Dialect()) + t1 := builder.Table(share.Table) + columns := sq.ctx.Fields + if len(columns) == 0 { + columns = share.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if sq.sql != nil { + selector = sq.sql + selector.Select(selector.Columns(columns...)...) + } + if sq.ctx.Unique != nil && *sq.ctx.Unique { + selector.Distinct() + } + for _, p := range sq.predicates { + p(selector) + } + for _, p := range sq.order { + p(selector) + } + if offset := sq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := sq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ShareGroupBy is the group-by builder for Share entities. +type ShareGroupBy struct { + selector + build *ShareQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (sgb *ShareGroupBy) Aggregate(fns ...AggregateFunc) *ShareGroupBy { + sgb.fns = append(sgb.fns, fns...) + return sgb +} + +// Scan applies the selector query and scans the result into the given value. +func (sgb *ShareGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") + if err := sgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ShareQuery, *ShareGroupBy](ctx, sgb.build, sgb, sgb.build.inters, v) +} + +func (sgb *ShareGroupBy) sqlScan(ctx context.Context, root *ShareQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sgb.fns)) + for _, fn := range sgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sgb.flds)+len(sgb.fns)) + for _, f := range *sgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*sgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := sgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ShareSelect is the builder for selecting fields of Share entities. +type ShareSelect struct { + *ShareQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ss *ShareSelect) Aggregate(fns ...AggregateFunc) *ShareSelect { + ss.fns = append(ss.fns, fns...) + return ss +} + +// Scan applies the selector query and scans the result into the given value. 
+func (ss *ShareSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ss.ctx, "Select") + if err := ss.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ShareQuery, *ShareSelect](ctx, ss.ShareQuery, ss, ss.inters, v) +} + +func (ss *ShareSelect) sqlScan(ctx context.Context, root *ShareQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ss.fns)) + for _, fn := range ss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ss.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/share_update.go b/ent/share_update.go new file mode 100644 index 00000000..b02e89b6 --- /dev/null +++ b/ent/share_update.go @@ -0,0 +1,778 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// ShareUpdate is the builder for updating Share entities. +type ShareUpdate struct { + config + hooks []Hook + mutation *ShareMutation +} + +// Where appends a list predicates to the ShareUpdate builder. +func (su *ShareUpdate) Where(ps ...predicate.Share) *ShareUpdate { + su.mutation.Where(ps...) + return su +} + +// SetUpdatedAt sets the "updated_at" field. +func (su *ShareUpdate) SetUpdatedAt(t time.Time) *ShareUpdate { + su.mutation.SetUpdatedAt(t) + return su +} + +// SetDeletedAt sets the "deleted_at" field. +func (su *ShareUpdate) SetDeletedAt(t time.Time) *ShareUpdate { + su.mutation.SetDeletedAt(t) + return su +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (su *ShareUpdate) SetNillableDeletedAt(t *time.Time) *ShareUpdate { + if t != nil { + su.SetDeletedAt(*t) + } + return su +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (su *ShareUpdate) ClearDeletedAt() *ShareUpdate { + su.mutation.ClearDeletedAt() + return su +} + +// SetPassword sets the "password" field. +func (su *ShareUpdate) SetPassword(s string) *ShareUpdate { + su.mutation.SetPassword(s) + return su +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (su *ShareUpdate) SetNillablePassword(s *string) *ShareUpdate { + if s != nil { + su.SetPassword(*s) + } + return su +} + +// ClearPassword clears the value of the "password" field. +func (su *ShareUpdate) ClearPassword() *ShareUpdate { + su.mutation.ClearPassword() + return su +} + +// SetViews sets the "views" field. +func (su *ShareUpdate) SetViews(i int) *ShareUpdate { + su.mutation.ResetViews() + su.mutation.SetViews(i) + return su +} + +// SetNillableViews sets the "views" field if the given value is not nil. +func (su *ShareUpdate) SetNillableViews(i *int) *ShareUpdate { + if i != nil { + su.SetViews(*i) + } + return su +} + +// AddViews adds i to the "views" field. 
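+// AddViews translates to an additive SQL update ("views" = "views" + i), so
+// concurrent increments are not lost. For example (shareID is a placeholder):
+//
+//	_, err := client.Share.Update().
+//		Where(share.ID(shareID)).
+//		AddViews(1).
+//		Save(ctx)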
+func (su *ShareUpdate) AddViews(i int) *ShareUpdate { + su.mutation.AddViews(i) + return su +} + +// SetDownloads sets the "downloads" field. +func (su *ShareUpdate) SetDownloads(i int) *ShareUpdate { + su.mutation.ResetDownloads() + su.mutation.SetDownloads(i) + return su +} + +// SetNillableDownloads sets the "downloads" field if the given value is not nil. +func (su *ShareUpdate) SetNillableDownloads(i *int) *ShareUpdate { + if i != nil { + su.SetDownloads(*i) + } + return su +} + +// AddDownloads adds i to the "downloads" field. +func (su *ShareUpdate) AddDownloads(i int) *ShareUpdate { + su.mutation.AddDownloads(i) + return su +} + +// SetExpires sets the "expires" field. +func (su *ShareUpdate) SetExpires(t time.Time) *ShareUpdate { + su.mutation.SetExpires(t) + return su +} + +// SetNillableExpires sets the "expires" field if the given value is not nil. +func (su *ShareUpdate) SetNillableExpires(t *time.Time) *ShareUpdate { + if t != nil { + su.SetExpires(*t) + } + return su +} + +// ClearExpires clears the value of the "expires" field. +func (su *ShareUpdate) ClearExpires() *ShareUpdate { + su.mutation.ClearExpires() + return su +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (su *ShareUpdate) SetRemainDownloads(i int) *ShareUpdate { + su.mutation.ResetRemainDownloads() + su.mutation.SetRemainDownloads(i) + return su +} + +// SetNillableRemainDownloads sets the "remain_downloads" field if the given value is not nil. +func (su *ShareUpdate) SetNillableRemainDownloads(i *int) *ShareUpdate { + if i != nil { + su.SetRemainDownloads(*i) + } + return su +} + +// AddRemainDownloads adds i to the "remain_downloads" field. +func (su *ShareUpdate) AddRemainDownloads(i int) *ShareUpdate { + su.mutation.AddRemainDownloads(i) + return su +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (su *ShareUpdate) ClearRemainDownloads() *ShareUpdate { + su.mutation.ClearRemainDownloads() + return su +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (su *ShareUpdate) SetUserID(id int) *ShareUpdate { + su.mutation.SetUserID(id) + return su +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (su *ShareUpdate) SetNillableUserID(id *int) *ShareUpdate { + if id != nil { + su = su.SetUserID(*id) + } + return su +} + +// SetUser sets the "user" edge to the User entity. +func (su *ShareUpdate) SetUser(u *User) *ShareUpdate { + return su.SetUserID(u.ID) +} + +// SetFileID sets the "file" edge to the File entity by ID. +func (su *ShareUpdate) SetFileID(id int) *ShareUpdate { + su.mutation.SetFileID(id) + return su +} + +// SetNillableFileID sets the "file" edge to the File entity by ID if the given value is not nil. +func (su *ShareUpdate) SetNillableFileID(id *int) *ShareUpdate { + if id != nil { + su = su.SetFileID(*id) + } + return su +} + +// SetFile sets the "file" edge to the File entity. +func (su *ShareUpdate) SetFile(f *File) *ShareUpdate { + return su.SetFileID(f.ID) +} + +// Mutation returns the ShareMutation object of the builder. +func (su *ShareUpdate) Mutation() *ShareMutation { + return su.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (su *ShareUpdate) ClearUser() *ShareUpdate { + su.mutation.ClearUser() + return su +} + +// ClearFile clears the "file" edge to the File entity. 
+func (su *ShareUpdate) ClearFile() *ShareUpdate { + su.mutation.ClearFile() + return su +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (su *ShareUpdate) Save(ctx context.Context) (int, error) { + if err := su.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, su.sqlSave, su.mutation, su.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (su *ShareUpdate) SaveX(ctx context.Context) int { + affected, err := su.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (su *ShareUpdate) Exec(ctx context.Context) error { + _, err := su.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (su *ShareUpdate) ExecX(ctx context.Context) { + if err := su.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (su *ShareUpdate) defaults() error { + if _, ok := su.mutation.UpdatedAt(); !ok { + if share.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized share.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := share.UpdateDefaultUpdatedAt() + su.mutation.SetUpdatedAt(v) + } + return nil +} + +func (su *ShareUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(share.Table, share.Columns, sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt)) + if ps := su.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := su.mutation.UpdatedAt(); ok { + _spec.SetField(share.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := su.mutation.DeletedAt(); ok { + _spec.SetField(share.FieldDeletedAt, field.TypeTime, value) + } + if su.mutation.DeletedAtCleared() { + _spec.ClearField(share.FieldDeletedAt, field.TypeTime) + } + if value, ok := su.mutation.Password(); ok { + _spec.SetField(share.FieldPassword, field.TypeString, value) + } + if su.mutation.PasswordCleared() { + _spec.ClearField(share.FieldPassword, field.TypeString) + } + if value, ok := su.mutation.Views(); ok { + _spec.SetField(share.FieldViews, field.TypeInt, value) + } + if value, ok := su.mutation.AddedViews(); ok { + _spec.AddField(share.FieldViews, field.TypeInt, value) + } + if value, ok := su.mutation.Downloads(); ok { + _spec.SetField(share.FieldDownloads, field.TypeInt, value) + } + if value, ok := su.mutation.AddedDownloads(); ok { + _spec.AddField(share.FieldDownloads, field.TypeInt, value) + } + if value, ok := su.mutation.Expires(); ok { + _spec.SetField(share.FieldExpires, field.TypeTime, value) + } + if su.mutation.ExpiresCleared() { + _spec.ClearField(share.FieldExpires, field.TypeTime) + } + if value, ok := su.mutation.RemainDownloads(); ok { + _spec.SetField(share.FieldRemainDownloads, field.TypeInt, value) + } + if value, ok := su.mutation.AddedRemainDownloads(); ok { + _spec.AddField(share.FieldRemainDownloads, field.TypeInt, value) + } + if su.mutation.RemainDownloadsCleared() { + _spec.ClearField(share.FieldRemainDownloads, field.TypeInt) + } + if su.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.UserTable, + Columns: []string{share.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := su.mutation.UserIDs(); 
len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.UserTable, + Columns: []string{share.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if su.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.FileTable, + Columns: []string{share.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := su.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.FileTable, + Columns: []string{share.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, su.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{share.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + su.mutation.done = true + return n, nil +} + +// ShareUpdateOne is the builder for updating a single Share entity. +type ShareUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ShareMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (suo *ShareUpdateOne) SetUpdatedAt(t time.Time) *ShareUpdateOne { + suo.mutation.SetUpdatedAt(t) + return suo +} + +// SetDeletedAt sets the "deleted_at" field. +func (suo *ShareUpdateOne) SetDeletedAt(t time.Time) *ShareUpdateOne { + suo.mutation.SetDeletedAt(t) + return suo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableDeletedAt(t *time.Time) *ShareUpdateOne { + if t != nil { + suo.SetDeletedAt(*t) + } + return suo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (suo *ShareUpdateOne) ClearDeletedAt() *ShareUpdateOne { + suo.mutation.ClearDeletedAt() + return suo +} + +// SetPassword sets the "password" field. +func (suo *ShareUpdateOne) SetPassword(s string) *ShareUpdateOne { + suo.mutation.SetPassword(s) + return suo +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillablePassword(s *string) *ShareUpdateOne { + if s != nil { + suo.SetPassword(*s) + } + return suo +} + +// ClearPassword clears the value of the "password" field. +func (suo *ShareUpdateOne) ClearPassword() *ShareUpdateOne { + suo.mutation.ClearPassword() + return suo +} + +// SetViews sets the "views" field. +func (suo *ShareUpdateOne) SetViews(i int) *ShareUpdateOne { + suo.mutation.ResetViews() + suo.mutation.SetViews(i) + return suo +} + +// SetNillableViews sets the "views" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableViews(i *int) *ShareUpdateOne { + if i != nil { + suo.SetViews(*i) + } + return suo +} + +// AddViews adds i to the "views" field. +func (suo *ShareUpdateOne) AddViews(i int) *ShareUpdateOne { + suo.mutation.AddViews(i) + return suo +} + +// SetDownloads sets the "downloads" field. 
+func (suo *ShareUpdateOne) SetDownloads(i int) *ShareUpdateOne { + suo.mutation.ResetDownloads() + suo.mutation.SetDownloads(i) + return suo +} + +// SetNillableDownloads sets the "downloads" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableDownloads(i *int) *ShareUpdateOne { + if i != nil { + suo.SetDownloads(*i) + } + return suo +} + +// AddDownloads adds i to the "downloads" field. +func (suo *ShareUpdateOne) AddDownloads(i int) *ShareUpdateOne { + suo.mutation.AddDownloads(i) + return suo +} + +// SetExpires sets the "expires" field. +func (suo *ShareUpdateOne) SetExpires(t time.Time) *ShareUpdateOne { + suo.mutation.SetExpires(t) + return suo +} + +// SetNillableExpires sets the "expires" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableExpires(t *time.Time) *ShareUpdateOne { + if t != nil { + suo.SetExpires(*t) + } + return suo +} + +// ClearExpires clears the value of the "expires" field. +func (suo *ShareUpdateOne) ClearExpires() *ShareUpdateOne { + suo.mutation.ClearExpires() + return suo +} + +// SetRemainDownloads sets the "remain_downloads" field. +func (suo *ShareUpdateOne) SetRemainDownloads(i int) *ShareUpdateOne { + suo.mutation.ResetRemainDownloads() + suo.mutation.SetRemainDownloads(i) + return suo +} + +// SetNillableRemainDownloads sets the "remain_downloads" field if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableRemainDownloads(i *int) *ShareUpdateOne { + if i != nil { + suo.SetRemainDownloads(*i) + } + return suo +} + +// AddRemainDownloads adds i to the "remain_downloads" field. +func (suo *ShareUpdateOne) AddRemainDownloads(i int) *ShareUpdateOne { + suo.mutation.AddRemainDownloads(i) + return suo +} + +// ClearRemainDownloads clears the value of the "remain_downloads" field. +func (suo *ShareUpdateOne) ClearRemainDownloads() *ShareUpdateOne { + suo.mutation.ClearRemainDownloads() + return suo +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (suo *ShareUpdateOne) SetUserID(id int) *ShareUpdateOne { + suo.mutation.SetUserID(id) + return suo +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableUserID(id *int) *ShareUpdateOne { + if id != nil { + suo = suo.SetUserID(*id) + } + return suo +} + +// SetUser sets the "user" edge to the User entity. +func (suo *ShareUpdateOne) SetUser(u *User) *ShareUpdateOne { + return suo.SetUserID(u.ID) +} + +// SetFileID sets the "file" edge to the File entity by ID. +func (suo *ShareUpdateOne) SetFileID(id int) *ShareUpdateOne { + suo.mutation.SetFileID(id) + return suo +} + +// SetNillableFileID sets the "file" edge to the File entity by ID if the given value is not nil. +func (suo *ShareUpdateOne) SetNillableFileID(id *int) *ShareUpdateOne { + if id != nil { + suo = suo.SetFileID(*id) + } + return suo +} + +// SetFile sets the "file" edge to the File entity. +func (suo *ShareUpdateOne) SetFile(f *File) *ShareUpdateOne { + return suo.SetFileID(f.ID) +} + +// Mutation returns the ShareMutation object of the builder. +func (suo *ShareUpdateOne) Mutation() *ShareMutation { + return suo.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (suo *ShareUpdateOne) ClearUser() *ShareUpdateOne { + suo.mutation.ClearUser() + return suo +} + +// ClearFile clears the "file" edge to the File entity. 
+func (suo *ShareUpdateOne) ClearFile() *ShareUpdateOne { + suo.mutation.ClearFile() + return suo +} + +// Where appends a list predicates to the ShareUpdate builder. +func (suo *ShareUpdateOne) Where(ps ...predicate.Share) *ShareUpdateOne { + suo.mutation.Where(ps...) + return suo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (suo *ShareUpdateOne) Select(field string, fields ...string) *ShareUpdateOne { + suo.fields = append([]string{field}, fields...) + return suo +} + +// Save executes the query and returns the updated Share entity. +func (suo *ShareUpdateOne) Save(ctx context.Context) (*Share, error) { + if err := suo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, suo.sqlSave, suo.mutation, suo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (suo *ShareUpdateOne) SaveX(ctx context.Context) *Share { + node, err := suo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (suo *ShareUpdateOne) Exec(ctx context.Context) error { + _, err := suo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (suo *ShareUpdateOne) ExecX(ctx context.Context) { + if err := suo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (suo *ShareUpdateOne) defaults() error { + if _, ok := suo.mutation.UpdatedAt(); !ok { + if share.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized share.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := share.UpdateDefaultUpdatedAt() + suo.mutation.SetUpdatedAt(v) + } + return nil +} + +func (suo *ShareUpdateOne) sqlSave(ctx context.Context) (_node *Share, err error) { + _spec := sqlgraph.NewUpdateSpec(share.Table, share.Columns, sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt)) + id, ok := suo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Share.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := suo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, share.FieldID) + for _, f := range fields { + if !share.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != share.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := suo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := suo.mutation.UpdatedAt(); ok { + _spec.SetField(share.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := suo.mutation.DeletedAt(); ok { + _spec.SetField(share.FieldDeletedAt, field.TypeTime, value) + } + if suo.mutation.DeletedAtCleared() { + _spec.ClearField(share.FieldDeletedAt, field.TypeTime) + } + if value, ok := suo.mutation.Password(); ok { + _spec.SetField(share.FieldPassword, field.TypeString, value) + } + if suo.mutation.PasswordCleared() { + _spec.ClearField(share.FieldPassword, field.TypeString) + } + if value, ok := suo.mutation.Views(); ok { + _spec.SetField(share.FieldViews, field.TypeInt, value) + } + if value, ok := suo.mutation.AddedViews(); ok { + _spec.AddField(share.FieldViews, field.TypeInt, value) + } + if value, ok := suo.mutation.Downloads(); ok { + 
_spec.SetField(share.FieldDownloads, field.TypeInt, value) + } + if value, ok := suo.mutation.AddedDownloads(); ok { + _spec.AddField(share.FieldDownloads, field.TypeInt, value) + } + if value, ok := suo.mutation.Expires(); ok { + _spec.SetField(share.FieldExpires, field.TypeTime, value) + } + if suo.mutation.ExpiresCleared() { + _spec.ClearField(share.FieldExpires, field.TypeTime) + } + if value, ok := suo.mutation.RemainDownloads(); ok { + _spec.SetField(share.FieldRemainDownloads, field.TypeInt, value) + } + if value, ok := suo.mutation.AddedRemainDownloads(); ok { + _spec.AddField(share.FieldRemainDownloads, field.TypeInt, value) + } + if suo.mutation.RemainDownloadsCleared() { + _spec.ClearField(share.FieldRemainDownloads, field.TypeInt) + } + if suo.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.UserTable, + Columns: []string{share.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := suo.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.UserTable, + Columns: []string{share.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if suo.mutation.FileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.FileTable, + Columns: []string{share.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := suo.mutation.FileIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: share.FileTable, + Columns: []string{share.FileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Share{config: suo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, suo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{share.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + suo.mutation.done = true + return _node, nil +} diff --git a/ent/storagepolicy.go b/ent/storagepolicy.go new file mode 100644 index 00000000..820b52ab --- /dev/null +++ b/ent/storagepolicy.go @@ -0,0 +1,396 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// StoragePolicy is the model entity for the StoragePolicy schema. +type StoragePolicy struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. 
+ CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Type holds the value of the "type" field. + Type string `json:"type,omitempty"` + // Server holds the value of the "server" field. + Server string `json:"server,omitempty"` + // BucketName holds the value of the "bucket_name" field. + BucketName string `json:"bucket_name,omitempty"` + // IsPrivate holds the value of the "is_private" field. + IsPrivate bool `json:"is_private,omitempty"` + // AccessKey holds the value of the "access_key" field. + AccessKey string `json:"access_key,omitempty"` + // SecretKey holds the value of the "secret_key" field. + SecretKey string `json:"secret_key,omitempty"` + // MaxSize holds the value of the "max_size" field. + MaxSize int64 `json:"max_size,omitempty"` + // DirNameRule holds the value of the "dir_name_rule" field. + DirNameRule string `json:"dir_name_rule,omitempty"` + // FileNameRule holds the value of the "file_name_rule" field. + FileNameRule string `json:"file_name_rule,omitempty"` + // Settings holds the value of the "settings" field. + Settings *types.PolicySetting `json:"settings,omitempty"` + // NodeID holds the value of the "node_id" field. + NodeID int `json:"node_id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the StoragePolicyQuery when eager-loading is set. + Edges StoragePolicyEdges `json:"edges"` + selectValues sql.SelectValues +} + +// StoragePolicyEdges holds the relations/edges for other nodes in the graph. +type StoragePolicyEdges struct { + // Users holds the value of the users edge. + Users []*User `json:"users,omitempty"` + // Groups holds the value of the groups edge. + Groups []*Group `json:"groups,omitempty"` + // Files holds the value of the files edge. + Files []*File `json:"files,omitempty"` + // Entities holds the value of the entities edge. + Entities []*Entity `json:"entities,omitempty"` + // Node holds the value of the node edge. + Node *Node `json:"node,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [5]bool +} + +// UsersOrErr returns the Users value or an error if the edge +// was not loaded in eager-loading. +func (e StoragePolicyEdges) UsersOrErr() ([]*User, error) { + if e.loadedTypes[0] { + return e.Users, nil + } + return nil, &NotLoadedError{edge: "users"} +} + +// GroupsOrErr returns the Groups value or an error if the edge +// was not loaded in eager-loading. +func (e StoragePolicyEdges) GroupsOrErr() ([]*Group, error) { + if e.loadedTypes[1] { + return e.Groups, nil + } + return nil, &NotLoadedError{edge: "groups"} +} + +// FilesOrErr returns the Files value or an error if the edge +// was not loaded in eager-loading. +func (e StoragePolicyEdges) FilesOrErr() ([]*File, error) { + if e.loadedTypes[2] { + return e.Files, nil + } + return nil, &NotLoadedError{edge: "files"} +} + +// EntitiesOrErr returns the Entities value or an error if the edge +// was not loaded in eager-loading. 
+func (e StoragePolicyEdges) EntitiesOrErr() ([]*Entity, error) { + if e.loadedTypes[3] { + return e.Entities, nil + } + return nil, &NotLoadedError{edge: "entities"} +} + +// NodeOrErr returns the Node value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e StoragePolicyEdges) NodeOrErr() (*Node, error) { + if e.loadedTypes[4] { + if e.Node == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: node.Label} + } + return e.Node, nil + } + return nil, &NotLoadedError{edge: "node"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*StoragePolicy) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case storagepolicy.FieldSettings: + values[i] = new([]byte) + case storagepolicy.FieldIsPrivate: + values[i] = new(sql.NullBool) + case storagepolicy.FieldID, storagepolicy.FieldMaxSize, storagepolicy.FieldNodeID: + values[i] = new(sql.NullInt64) + case storagepolicy.FieldName, storagepolicy.FieldType, storagepolicy.FieldServer, storagepolicy.FieldBucketName, storagepolicy.FieldAccessKey, storagepolicy.FieldSecretKey, storagepolicy.FieldDirNameRule, storagepolicy.FieldFileNameRule: + values[i] = new(sql.NullString) + case storagepolicy.FieldCreatedAt, storagepolicy.FieldUpdatedAt, storagepolicy.FieldDeletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the StoragePolicy fields. +func (sp *StoragePolicy) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case storagepolicy.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + sp.ID = int(value.Int64) + case storagepolicy.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + sp.CreatedAt = value.Time + } + case storagepolicy.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + sp.UpdatedAt = value.Time + } + case storagepolicy.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + sp.DeletedAt = new(time.Time) + *sp.DeletedAt = value.Time + } + case storagepolicy.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + sp.Name = value.String + } + case storagepolicy.FieldType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + sp.Type = value.String + } + case storagepolicy.FieldServer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field server", values[i]) + } else if value.Valid { + sp.Server = value.String + } + case storagepolicy.FieldBucketName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field bucket_name", 
values[i]) + } else if value.Valid { + sp.BucketName = value.String + } + case storagepolicy.FieldIsPrivate: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_private", values[i]) + } else if value.Valid { + sp.IsPrivate = value.Bool + } + case storagepolicy.FieldAccessKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field access_key", values[i]) + } else if value.Valid { + sp.AccessKey = value.String + } + case storagepolicy.FieldSecretKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field secret_key", values[i]) + } else if value.Valid { + sp.SecretKey = value.String + } + case storagepolicy.FieldMaxSize: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field max_size", values[i]) + } else if value.Valid { + sp.MaxSize = value.Int64 + } + case storagepolicy.FieldDirNameRule: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field dir_name_rule", values[i]) + } else if value.Valid { + sp.DirNameRule = value.String + } + case storagepolicy.FieldFileNameRule: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field file_name_rule", values[i]) + } else if value.Valid { + sp.FileNameRule = value.String + } + case storagepolicy.FieldSettings: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field settings", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &sp.Settings); err != nil { + return fmt.Errorf("unmarshal field settings: %w", err) + } + } + case storagepolicy.FieldNodeID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field node_id", values[i]) + } else if value.Valid { + sp.NodeID = int(value.Int64) + } + default: + sp.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the StoragePolicy. +// This includes values selected through modifiers, order, etc. +func (sp *StoragePolicy) Value(name string) (ent.Value, error) { + return sp.selectValues.Get(name) +} + +// QueryUsers queries the "users" edge of the StoragePolicy entity. +func (sp *StoragePolicy) QueryUsers() *UserQuery { + return NewStoragePolicyClient(sp.config).QueryUsers(sp) +} + +// QueryGroups queries the "groups" edge of the StoragePolicy entity. +func (sp *StoragePolicy) QueryGroups() *GroupQuery { + return NewStoragePolicyClient(sp.config).QueryGroups(sp) +} + +// QueryFiles queries the "files" edge of the StoragePolicy entity. +func (sp *StoragePolicy) QueryFiles() *FileQuery { + return NewStoragePolicyClient(sp.config).QueryFiles(sp) +} + +// QueryEntities queries the "entities" edge of the StoragePolicy entity. +func (sp *StoragePolicy) QueryEntities() *EntityQuery { + return NewStoragePolicyClient(sp.config).QueryEntities(sp) +} + +// QueryNode queries the "node" edge of the StoragePolicy entity. +func (sp *StoragePolicy) QueryNode() *NodeQuery { + return NewStoragePolicyClient(sp.config).QueryNode(sp) +} + +// Update returns a builder for updating this StoragePolicy. +// Note that you need to call StoragePolicy.Unwrap() before calling this method if this StoragePolicy +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (sp *StoragePolicy) Update() *StoragePolicyUpdateOne { + return NewStoragePolicyClient(sp.config).UpdateOne(sp) +} + +// Unwrap unwraps the StoragePolicy entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (sp *StoragePolicy) Unwrap() *StoragePolicy { + _tx, ok := sp.config.driver.(*txDriver) + if !ok { + panic("ent: StoragePolicy is not a transactional entity") + } + sp.config.driver = _tx.drv + return sp +} + +// String implements the fmt.Stringer. +func (sp *StoragePolicy) String() string { + var builder strings.Builder + builder.WriteString("StoragePolicy(") + builder.WriteString(fmt.Sprintf("id=%v, ", sp.ID)) + builder.WriteString("created_at=") + builder.WriteString(sp.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(sp.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := sp.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(sp.Name) + builder.WriteString(", ") + builder.WriteString("type=") + builder.WriteString(sp.Type) + builder.WriteString(", ") + builder.WriteString("server=") + builder.WriteString(sp.Server) + builder.WriteString(", ") + builder.WriteString("bucket_name=") + builder.WriteString(sp.BucketName) + builder.WriteString(", ") + builder.WriteString("is_private=") + builder.WriteString(fmt.Sprintf("%v", sp.IsPrivate)) + builder.WriteString(", ") + builder.WriteString("access_key=") + builder.WriteString(sp.AccessKey) + builder.WriteString(", ") + builder.WriteString("secret_key=") + builder.WriteString(sp.SecretKey) + builder.WriteString(", ") + builder.WriteString("max_size=") + builder.WriteString(fmt.Sprintf("%v", sp.MaxSize)) + builder.WriteString(", ") + builder.WriteString("dir_name_rule=") + builder.WriteString(sp.DirNameRule) + builder.WriteString(", ") + builder.WriteString("file_name_rule=") + builder.WriteString(sp.FileNameRule) + builder.WriteString(", ") + builder.WriteString("settings=") + builder.WriteString(fmt.Sprintf("%v", sp.Settings)) + builder.WriteString(", ") + builder.WriteString("node_id=") + builder.WriteString(fmt.Sprintf("%v", sp.NodeID)) + builder.WriteByte(')') + return builder.String() +} + +// SetUsers manually set the edge as loaded state. +func (e *StoragePolicy) SetUsers(v []*User) { + e.Edges.Users = v + e.Edges.loadedTypes[0] = true +} + +// SetGroups manually set the edge as loaded state. +func (e *StoragePolicy) SetGroups(v []*Group) { + e.Edges.Groups = v + e.Edges.loadedTypes[1] = true +} + +// SetFiles manually set the edge as loaded state. +func (e *StoragePolicy) SetFiles(v []*File) { + e.Edges.Files = v + e.Edges.loadedTypes[2] = true +} + +// SetEntities manually set the edge as loaded state. +func (e *StoragePolicy) SetEntities(v []*Entity) { + e.Edges.Entities = v + e.Edges.loadedTypes[3] = true +} + +// SetNode manually set the edge as loaded state. +func (e *StoragePolicy) SetNode(v *Node) { + e.Edges.Node = v + e.Edges.loadedTypes[4] = true +} + +// StoragePolicies is a parsable slice of StoragePolicy. +type StoragePolicies []*StoragePolicy diff --git a/ent/storagepolicy/storagepolicy.go b/ent/storagepolicy/storagepolicy.go new file mode 100644 index 00000000..0c66ba90 --- /dev/null +++ b/ent/storagepolicy/storagepolicy.go @@ -0,0 +1,320 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package storagepolicy + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +const ( + // Label holds the string label denoting the storagepolicy type in the database. + Label = "storage_policy" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldServer holds the string denoting the server field in the database. + FieldServer = "server" + // FieldBucketName holds the string denoting the bucket_name field in the database. + FieldBucketName = "bucket_name" + // FieldIsPrivate holds the string denoting the is_private field in the database. + FieldIsPrivate = "is_private" + // FieldAccessKey holds the string denoting the access_key field in the database. + FieldAccessKey = "access_key" + // FieldSecretKey holds the string denoting the secret_key field in the database. + FieldSecretKey = "secret_key" + // FieldMaxSize holds the string denoting the max_size field in the database. + FieldMaxSize = "max_size" + // FieldDirNameRule holds the string denoting the dir_name_rule field in the database. + FieldDirNameRule = "dir_name_rule" + // FieldFileNameRule holds the string denoting the file_name_rule field in the database. + FieldFileNameRule = "file_name_rule" + // FieldSettings holds the string denoting the settings field in the database. + FieldSettings = "settings" + // FieldNodeID holds the string denoting the node_id field in the database. + FieldNodeID = "node_id" + // EdgeUsers holds the string denoting the users edge name in mutations. + EdgeUsers = "users" + // EdgeGroups holds the string denoting the groups edge name in mutations. + EdgeGroups = "groups" + // EdgeFiles holds the string denoting the files edge name in mutations. + EdgeFiles = "files" + // EdgeEntities holds the string denoting the entities edge name in mutations. + EdgeEntities = "entities" + // EdgeNode holds the string denoting the node edge name in mutations. + EdgeNode = "node" + // Table holds the table name of the storagepolicy in the database. + Table = "storage_policies" + // UsersTable is the table that holds the users relation/edge. + UsersTable = "users" + // UsersInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UsersInverseTable = "users" + // UsersColumn is the table column denoting the users relation/edge. + UsersColumn = "storage_policy_users" + // GroupsTable is the table that holds the groups relation/edge. + GroupsTable = "groups" + // GroupsInverseTable is the table name for the Group entity. + // It exists in this package in order to avoid circular dependency with the "group" package. + GroupsInverseTable = "groups" + // GroupsColumn is the table column denoting the groups relation/edge. + GroupsColumn = "storage_policy_id" + // FilesTable is the table that holds the files relation/edge. 
+ FilesTable = "files" + // FilesInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. + FilesInverseTable = "files" + // FilesColumn is the table column denoting the files relation/edge. + FilesColumn = "storage_policy_files" + // EntitiesTable is the table that holds the entities relation/edge. + EntitiesTable = "entities" + // EntitiesInverseTable is the table name for the Entity entity. + // It exists in this package in order to avoid circular dependency with the "entity" package. + EntitiesInverseTable = "entities" + // EntitiesColumn is the table column denoting the entities relation/edge. + EntitiesColumn = "storage_policy_entities" + // NodeTable is the table that holds the node relation/edge. + NodeTable = "storage_policies" + // NodeInverseTable is the table name for the Node entity. + // It exists in this package in order to avoid circular dependency with the "node" package. + NodeInverseTable = "nodes" + // NodeColumn is the table column denoting the node relation/edge. + NodeColumn = "node_id" +) + +// Columns holds all SQL columns for storagepolicy fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldName, + FieldType, + FieldServer, + FieldBucketName, + FieldIsPrivate, + FieldAccessKey, + FieldSecretKey, + FieldMaxSize, + FieldDirNameRule, + FieldFileNameRule, + FieldSettings, + FieldNodeID, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultSettings holds the default value on creation for the "settings" field. + DefaultSettings *types.PolicySetting +) + +// OrderOption defines the ordering options for the StoragePolicy queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByType orders the results by the type field. 
+func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByServer orders the results by the server field. +func ByServer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldServer, opts...).ToFunc() +} + +// ByBucketName orders the results by the bucket_name field. +func ByBucketName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBucketName, opts...).ToFunc() +} + +// ByIsPrivate orders the results by the is_private field. +func ByIsPrivate(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsPrivate, opts...).ToFunc() +} + +// ByAccessKey orders the results by the access_key field. +func ByAccessKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccessKey, opts...).ToFunc() +} + +// BySecretKey orders the results by the secret_key field. +func BySecretKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSecretKey, opts...).ToFunc() +} + +// ByMaxSize orders the results by the max_size field. +func ByMaxSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMaxSize, opts...).ToFunc() +} + +// ByDirNameRule orders the results by the dir_name_rule field. +func ByDirNameRule(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDirNameRule, opts...).ToFunc() +} + +// ByFileNameRule orders the results by the file_name_rule field. +func ByFileNameRule(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFileNameRule, opts...).ToFunc() +} + +// ByNodeID orders the results by the node_id field. +func ByNodeID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNodeID, opts...).ToFunc() +} + +// ByUsersCount orders the results by users count. +func ByUsersCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsersStep(), opts...) + } +} + +// ByUsers orders the results by users terms. +func ByUsers(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsersStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByGroupsCount orders the results by groups count. +func ByGroupsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newGroupsStep(), opts...) + } +} + +// ByGroups orders the results by groups terms. +func ByGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newGroupsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByFilesCount orders the results by files count. +func ByFilesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newFilesStep(), opts...) + } +} + +// ByFiles orders the results by files terms. +func ByFiles(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFilesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEntitiesCount orders the results by entities count. +func ByEntitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEntitiesStep(), opts...) + } +} + +// ByEntities orders the results by entities terms. 
+func ByEntities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEntitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByNodeField orders the results by node field. +func ByNodeField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newNodeStep(), sql.OrderByField(field, opts...)) + } +} +func newUsersStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsersInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsersTable, UsersColumn), + ) +} +func newGroupsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(GroupsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, GroupsTable, GroupsColumn), + ) +} +func newFilesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FilesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, FilesTable, FilesColumn), + ) +} +func newEntitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EntitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EntitiesTable, EntitiesColumn), + ) +} +func newNodeStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(NodeInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, NodeTable, NodeColumn), + ) +} diff --git a/ent/storagepolicy/where.go b/ent/storagepolicy/where.go new file mode 100644 index 00000000..13146b51 --- /dev/null +++ b/ent/storagepolicy/where.go @@ -0,0 +1,1076 @@ +// Code generated by ent, DO NOT EDIT. + +package storagepolicy + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
+func CreatedAt(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldName, v)) +} + +// Type applies equality check predicate on the "type" field. It's identical to TypeEQ. +func Type(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldType, v)) +} + +// Server applies equality check predicate on the "server" field. It's identical to ServerEQ. +func Server(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldServer, v)) +} + +// BucketName applies equality check predicate on the "bucket_name" field. It's identical to BucketNameEQ. +func BucketName(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldBucketName, v)) +} + +// IsPrivate applies equality check predicate on the "is_private" field. It's identical to IsPrivateEQ. +func IsPrivate(v bool) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldIsPrivate, v)) +} + +// AccessKey applies equality check predicate on the "access_key" field. It's identical to AccessKeyEQ. +func AccessKey(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldAccessKey, v)) +} + +// SecretKey applies equality check predicate on the "secret_key" field. It's identical to SecretKeyEQ. +func SecretKey(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldSecretKey, v)) +} + +// MaxSize applies equality check predicate on the "max_size" field. It's identical to MaxSizeEQ. +func MaxSize(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldMaxSize, v)) +} + +// DirNameRule applies equality check predicate on the "dir_name_rule" field. It's identical to DirNameRuleEQ. +func DirNameRule(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldDirNameRule, v)) +} + +// FileNameRule applies equality check predicate on the "file_name_rule" field. It's identical to FileNameRuleEQ. +func FileNameRule(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldFileNameRule, v)) +} + +// NodeID applies equality check predicate on the "node_id" field. It's identical to NodeIDEQ. +func NodeID(v int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldNodeID, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. 
+func CreatedAtIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. 
+func DeletedAtNotIn(vs ...time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldDeletedAt)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. 
+func NameEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldName, v)) +} + +// TypeEQ applies the EQ predicate on the "type" field. +func TypeEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldType, vs...)) +} + +// TypeGT applies the GT predicate on the "type" field. +func TypeGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldType, v)) +} + +// TypeGTE applies the GTE predicate on the "type" field. +func TypeGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldType, v)) +} + +// TypeLT applies the LT predicate on the "type" field. +func TypeLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldType, v)) +} + +// TypeLTE applies the LTE predicate on the "type" field. +func TypeLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldType, v)) +} + +// TypeContains applies the Contains predicate on the "type" field. +func TypeContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldType, v)) +} + +// TypeHasPrefix applies the HasPrefix predicate on the "type" field. +func TypeHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldType, v)) +} + +// TypeHasSuffix applies the HasSuffix predicate on the "type" field. +func TypeHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldType, v)) +} + +// TypeEqualFold applies the EqualFold predicate on the "type" field. +func TypeEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldType, v)) +} + +// TypeContainsFold applies the ContainsFold predicate on the "type" field. +func TypeContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldType, v)) +} + +// ServerEQ applies the EQ predicate on the "server" field. +func ServerEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldServer, v)) +} + +// ServerNEQ applies the NEQ predicate on the "server" field. +func ServerNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldServer, v)) +} + +// ServerIn applies the In predicate on the "server" field. +func ServerIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldServer, vs...)) +} + +// ServerNotIn applies the NotIn predicate on the "server" field. 
+func ServerNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldServer, vs...)) +} + +// ServerGT applies the GT predicate on the "server" field. +func ServerGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldServer, v)) +} + +// ServerGTE applies the GTE predicate on the "server" field. +func ServerGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldServer, v)) +} + +// ServerLT applies the LT predicate on the "server" field. +func ServerLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldServer, v)) +} + +// ServerLTE applies the LTE predicate on the "server" field. +func ServerLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldServer, v)) +} + +// ServerContains applies the Contains predicate on the "server" field. +func ServerContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldServer, v)) +} + +// ServerHasPrefix applies the HasPrefix predicate on the "server" field. +func ServerHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldServer, v)) +} + +// ServerHasSuffix applies the HasSuffix predicate on the "server" field. +func ServerHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldServer, v)) +} + +// ServerIsNil applies the IsNil predicate on the "server" field. +func ServerIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldServer)) +} + +// ServerNotNil applies the NotNil predicate on the "server" field. +func ServerNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldServer)) +} + +// ServerEqualFold applies the EqualFold predicate on the "server" field. +func ServerEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldServer, v)) +} + +// ServerContainsFold applies the ContainsFold predicate on the "server" field. +func ServerContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldServer, v)) +} + +// BucketNameEQ applies the EQ predicate on the "bucket_name" field. +func BucketNameEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldBucketName, v)) +} + +// BucketNameNEQ applies the NEQ predicate on the "bucket_name" field. +func BucketNameNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldBucketName, v)) +} + +// BucketNameIn applies the In predicate on the "bucket_name" field. +func BucketNameIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldBucketName, vs...)) +} + +// BucketNameNotIn applies the NotIn predicate on the "bucket_name" field. +func BucketNameNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldBucketName, vs...)) +} + +// BucketNameGT applies the GT predicate on the "bucket_name" field. +func BucketNameGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldBucketName, v)) +} + +// BucketNameGTE applies the GTE predicate on the "bucket_name" field. +func BucketNameGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldBucketName, v)) +} + +// BucketNameLT applies the LT predicate on the "bucket_name" field. 
+func BucketNameLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldBucketName, v)) +} + +// BucketNameLTE applies the LTE predicate on the "bucket_name" field. +func BucketNameLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldBucketName, v)) +} + +// BucketNameContains applies the Contains predicate on the "bucket_name" field. +func BucketNameContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldBucketName, v)) +} + +// BucketNameHasPrefix applies the HasPrefix predicate on the "bucket_name" field. +func BucketNameHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldBucketName, v)) +} + +// BucketNameHasSuffix applies the HasSuffix predicate on the "bucket_name" field. +func BucketNameHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldBucketName, v)) +} + +// BucketNameIsNil applies the IsNil predicate on the "bucket_name" field. +func BucketNameIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldBucketName)) +} + +// BucketNameNotNil applies the NotNil predicate on the "bucket_name" field. +func BucketNameNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldBucketName)) +} + +// BucketNameEqualFold applies the EqualFold predicate on the "bucket_name" field. +func BucketNameEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldBucketName, v)) +} + +// BucketNameContainsFold applies the ContainsFold predicate on the "bucket_name" field. +func BucketNameContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldBucketName, v)) +} + +// IsPrivateEQ applies the EQ predicate on the "is_private" field. +func IsPrivateEQ(v bool) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldIsPrivate, v)) +} + +// IsPrivateNEQ applies the NEQ predicate on the "is_private" field. +func IsPrivateNEQ(v bool) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldIsPrivate, v)) +} + +// IsPrivateIsNil applies the IsNil predicate on the "is_private" field. +func IsPrivateIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldIsPrivate)) +} + +// IsPrivateNotNil applies the NotNil predicate on the "is_private" field. +func IsPrivateNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldIsPrivate)) +} + +// AccessKeyEQ applies the EQ predicate on the "access_key" field. +func AccessKeyEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldAccessKey, v)) +} + +// AccessKeyNEQ applies the NEQ predicate on the "access_key" field. +func AccessKeyNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldAccessKey, v)) +} + +// AccessKeyIn applies the In predicate on the "access_key" field. +func AccessKeyIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldAccessKey, vs...)) +} + +// AccessKeyNotIn applies the NotIn predicate on the "access_key" field. +func AccessKeyNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldAccessKey, vs...)) +} + +// AccessKeyGT applies the GT predicate on the "access_key" field. 
+func AccessKeyGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldAccessKey, v)) +} + +// AccessKeyGTE applies the GTE predicate on the "access_key" field. +func AccessKeyGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldAccessKey, v)) +} + +// AccessKeyLT applies the LT predicate on the "access_key" field. +func AccessKeyLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldAccessKey, v)) +} + +// AccessKeyLTE applies the LTE predicate on the "access_key" field. +func AccessKeyLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldAccessKey, v)) +} + +// AccessKeyContains applies the Contains predicate on the "access_key" field. +func AccessKeyContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldAccessKey, v)) +} + +// AccessKeyHasPrefix applies the HasPrefix predicate on the "access_key" field. +func AccessKeyHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldAccessKey, v)) +} + +// AccessKeyHasSuffix applies the HasSuffix predicate on the "access_key" field. +func AccessKeyHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldAccessKey, v)) +} + +// AccessKeyIsNil applies the IsNil predicate on the "access_key" field. +func AccessKeyIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldAccessKey)) +} + +// AccessKeyNotNil applies the NotNil predicate on the "access_key" field. +func AccessKeyNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldAccessKey)) +} + +// AccessKeyEqualFold applies the EqualFold predicate on the "access_key" field. +func AccessKeyEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldAccessKey, v)) +} + +// AccessKeyContainsFold applies the ContainsFold predicate on the "access_key" field. +func AccessKeyContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldAccessKey, v)) +} + +// SecretKeyEQ applies the EQ predicate on the "secret_key" field. +func SecretKeyEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldSecretKey, v)) +} + +// SecretKeyNEQ applies the NEQ predicate on the "secret_key" field. +func SecretKeyNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldSecretKey, v)) +} + +// SecretKeyIn applies the In predicate on the "secret_key" field. +func SecretKeyIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldSecretKey, vs...)) +} + +// SecretKeyNotIn applies the NotIn predicate on the "secret_key" field. +func SecretKeyNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldSecretKey, vs...)) +} + +// SecretKeyGT applies the GT predicate on the "secret_key" field. +func SecretKeyGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldSecretKey, v)) +} + +// SecretKeyGTE applies the GTE predicate on the "secret_key" field. +func SecretKeyGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldSecretKey, v)) +} + +// SecretKeyLT applies the LT predicate on the "secret_key" field. 
+func SecretKeyLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldSecretKey, v)) +} + +// SecretKeyLTE applies the LTE predicate on the "secret_key" field. +func SecretKeyLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldSecretKey, v)) +} + +// SecretKeyContains applies the Contains predicate on the "secret_key" field. +func SecretKeyContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldSecretKey, v)) +} + +// SecretKeyHasPrefix applies the HasPrefix predicate on the "secret_key" field. +func SecretKeyHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldSecretKey, v)) +} + +// SecretKeyHasSuffix applies the HasSuffix predicate on the "secret_key" field. +func SecretKeyHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldSecretKey, v)) +} + +// SecretKeyIsNil applies the IsNil predicate on the "secret_key" field. +func SecretKeyIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldSecretKey)) +} + +// SecretKeyNotNil applies the NotNil predicate on the "secret_key" field. +func SecretKeyNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldSecretKey)) +} + +// SecretKeyEqualFold applies the EqualFold predicate on the "secret_key" field. +func SecretKeyEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldSecretKey, v)) +} + +// SecretKeyContainsFold applies the ContainsFold predicate on the "secret_key" field. +func SecretKeyContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldSecretKey, v)) +} + +// MaxSizeEQ applies the EQ predicate on the "max_size" field. +func MaxSizeEQ(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldMaxSize, v)) +} + +// MaxSizeNEQ applies the NEQ predicate on the "max_size" field. +func MaxSizeNEQ(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldMaxSize, v)) +} + +// MaxSizeIn applies the In predicate on the "max_size" field. +func MaxSizeIn(vs ...int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldMaxSize, vs...)) +} + +// MaxSizeNotIn applies the NotIn predicate on the "max_size" field. +func MaxSizeNotIn(vs ...int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldMaxSize, vs...)) +} + +// MaxSizeGT applies the GT predicate on the "max_size" field. +func MaxSizeGT(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldMaxSize, v)) +} + +// MaxSizeGTE applies the GTE predicate on the "max_size" field. +func MaxSizeGTE(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldMaxSize, v)) +} + +// MaxSizeLT applies the LT predicate on the "max_size" field. +func MaxSizeLT(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldMaxSize, v)) +} + +// MaxSizeLTE applies the LTE predicate on the "max_size" field. +func MaxSizeLTE(v int64) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldMaxSize, v)) +} + +// MaxSizeIsNil applies the IsNil predicate on the "max_size" field. 
+func MaxSizeIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldMaxSize)) +} + +// MaxSizeNotNil applies the NotNil predicate on the "max_size" field. +func MaxSizeNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldMaxSize)) +} + +// DirNameRuleEQ applies the EQ predicate on the "dir_name_rule" field. +func DirNameRuleEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldDirNameRule, v)) +} + +// DirNameRuleNEQ applies the NEQ predicate on the "dir_name_rule" field. +func DirNameRuleNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldDirNameRule, v)) +} + +// DirNameRuleIn applies the In predicate on the "dir_name_rule" field. +func DirNameRuleIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldDirNameRule, vs...)) +} + +// DirNameRuleNotIn applies the NotIn predicate on the "dir_name_rule" field. +func DirNameRuleNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldDirNameRule, vs...)) +} + +// DirNameRuleGT applies the GT predicate on the "dir_name_rule" field. +func DirNameRuleGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldDirNameRule, v)) +} + +// DirNameRuleGTE applies the GTE predicate on the "dir_name_rule" field. +func DirNameRuleGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldDirNameRule, v)) +} + +// DirNameRuleLT applies the LT predicate on the "dir_name_rule" field. +func DirNameRuleLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldDirNameRule, v)) +} + +// DirNameRuleLTE applies the LTE predicate on the "dir_name_rule" field. +func DirNameRuleLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldDirNameRule, v)) +} + +// DirNameRuleContains applies the Contains predicate on the "dir_name_rule" field. +func DirNameRuleContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldDirNameRule, v)) +} + +// DirNameRuleHasPrefix applies the HasPrefix predicate on the "dir_name_rule" field. +func DirNameRuleHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldDirNameRule, v)) +} + +// DirNameRuleHasSuffix applies the HasSuffix predicate on the "dir_name_rule" field. +func DirNameRuleHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldDirNameRule, v)) +} + +// DirNameRuleIsNil applies the IsNil predicate on the "dir_name_rule" field. +func DirNameRuleIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldDirNameRule)) +} + +// DirNameRuleNotNil applies the NotNil predicate on the "dir_name_rule" field. +func DirNameRuleNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldDirNameRule)) +} + +// DirNameRuleEqualFold applies the EqualFold predicate on the "dir_name_rule" field. +func DirNameRuleEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldDirNameRule, v)) +} + +// DirNameRuleContainsFold applies the ContainsFold predicate on the "dir_name_rule" field. 
+func DirNameRuleContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldDirNameRule, v)) +} + +// FileNameRuleEQ applies the EQ predicate on the "file_name_rule" field. +func FileNameRuleEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldFileNameRule, v)) +} + +// FileNameRuleNEQ applies the NEQ predicate on the "file_name_rule" field. +func FileNameRuleNEQ(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldFileNameRule, v)) +} + +// FileNameRuleIn applies the In predicate on the "file_name_rule" field. +func FileNameRuleIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldFileNameRule, vs...)) +} + +// FileNameRuleNotIn applies the NotIn predicate on the "file_name_rule" field. +func FileNameRuleNotIn(vs ...string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldFileNameRule, vs...)) +} + +// FileNameRuleGT applies the GT predicate on the "file_name_rule" field. +func FileNameRuleGT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGT(FieldFileNameRule, v)) +} + +// FileNameRuleGTE applies the GTE predicate on the "file_name_rule" field. +func FileNameRuleGTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldGTE(FieldFileNameRule, v)) +} + +// FileNameRuleLT applies the LT predicate on the "file_name_rule" field. +func FileNameRuleLT(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLT(FieldFileNameRule, v)) +} + +// FileNameRuleLTE applies the LTE predicate on the "file_name_rule" field. +func FileNameRuleLTE(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldLTE(FieldFileNameRule, v)) +} + +// FileNameRuleContains applies the Contains predicate on the "file_name_rule" field. +func FileNameRuleContains(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContains(FieldFileNameRule, v)) +} + +// FileNameRuleHasPrefix applies the HasPrefix predicate on the "file_name_rule" field. +func FileNameRuleHasPrefix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasPrefix(FieldFileNameRule, v)) +} + +// FileNameRuleHasSuffix applies the HasSuffix predicate on the "file_name_rule" field. +func FileNameRuleHasSuffix(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldHasSuffix(FieldFileNameRule, v)) +} + +// FileNameRuleIsNil applies the IsNil predicate on the "file_name_rule" field. +func FileNameRuleIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldFileNameRule)) +} + +// FileNameRuleNotNil applies the NotNil predicate on the "file_name_rule" field. +func FileNameRuleNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldFileNameRule)) +} + +// FileNameRuleEqualFold applies the EqualFold predicate on the "file_name_rule" field. +func FileNameRuleEqualFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEqualFold(FieldFileNameRule, v)) +} + +// FileNameRuleContainsFold applies the ContainsFold predicate on the "file_name_rule" field. +func FileNameRuleContainsFold(v string) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldContainsFold(FieldFileNameRule, v)) +} + +// SettingsIsNil applies the IsNil predicate on the "settings" field. 
+func SettingsIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldSettings)) +} + +// SettingsNotNil applies the NotNil predicate on the "settings" field. +func SettingsNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldSettings)) +} + +// NodeIDEQ applies the EQ predicate on the "node_id" field. +func NodeIDEQ(v int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldEQ(FieldNodeID, v)) +} + +// NodeIDNEQ applies the NEQ predicate on the "node_id" field. +func NodeIDNEQ(v int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNEQ(FieldNodeID, v)) +} + +// NodeIDIn applies the In predicate on the "node_id" field. +func NodeIDIn(vs ...int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIn(FieldNodeID, vs...)) +} + +// NodeIDNotIn applies the NotIn predicate on the "node_id" field. +func NodeIDNotIn(vs ...int) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotIn(FieldNodeID, vs...)) +} + +// NodeIDIsNil applies the IsNil predicate on the "node_id" field. +func NodeIDIsNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldIsNull(FieldNodeID)) +} + +// NodeIDNotNil applies the NotNil predicate on the "node_id" field. +func NodeIDNotNil() predicate.StoragePolicy { + return predicate.StoragePolicy(sql.FieldNotNull(FieldNodeID)) +} + +// HasUsers applies the HasEdge predicate on the "users" edge. +func HasUsers() predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsersTable, UsersColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsersWith applies the HasEdge predicate on the "users" edge with a given conditions (other predicates). +func HasUsersWith(preds ...predicate.User) predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := newUsersStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasGroups applies the HasEdge predicate on the "groups" edge. +func HasGroups() predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, GroupsTable, GroupsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasGroupsWith applies the HasEdge predicate on the "groups" edge with a given conditions (other predicates). +func HasGroupsWith(preds ...predicate.Group) predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := newGroupsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasFiles applies the HasEdge predicate on the "files" edge. +func HasFiles() predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, FilesTable, FilesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFilesWith applies the HasEdge predicate on the "files" edge with a given conditions (other predicates). 
+func HasFilesWith(preds ...predicate.File) predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := newFilesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasEntities applies the HasEdge predicate on the "entities" edge. +func HasEntities() predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EntitiesTable, EntitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasEntitiesWith applies the HasEdge predicate on the "entities" edge with a given conditions (other predicates). +func HasEntitiesWith(preds ...predicate.Entity) predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := newEntitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasNode applies the HasEdge predicate on the "node" edge. +func HasNode() predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, NodeTable, NodeColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasNodeWith applies the HasEdge predicate on the "node" edge with a given conditions (other predicates). +func HasNodeWith(preds ...predicate.Node) predicate.StoragePolicy { + return predicate.StoragePolicy(func(s *sql.Selector) { + step := newNodeStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.StoragePolicy) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.StoragePolicy) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.StoragePolicy) predicate.StoragePolicy { + return predicate.StoragePolicy(sql.NotPredicates(p)) +} diff --git a/ent/storagepolicy_create.go b/ent/storagepolicy_create.go new file mode 100644 index 00000000..32b54ac8 --- /dev/null +++ b/ent/storagepolicy_create.go @@ -0,0 +1,1659 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// StoragePolicyCreate is the builder for creating a StoragePolicy entity. +type StoragePolicyCreate struct { + config + mutation *StoragePolicyMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (spc *StoragePolicyCreate) SetCreatedAt(t time.Time) *StoragePolicyCreate { + spc.mutation.SetCreatedAt(t) + return spc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. 
+func (spc *StoragePolicyCreate) SetNillableCreatedAt(t *time.Time) *StoragePolicyCreate { + if t != nil { + spc.SetCreatedAt(*t) + } + return spc +} + +// SetUpdatedAt sets the "updated_at" field. +func (spc *StoragePolicyCreate) SetUpdatedAt(t time.Time) *StoragePolicyCreate { + spc.mutation.SetUpdatedAt(t) + return spc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableUpdatedAt(t *time.Time) *StoragePolicyCreate { + if t != nil { + spc.SetUpdatedAt(*t) + } + return spc +} + +// SetDeletedAt sets the "deleted_at" field. +func (spc *StoragePolicyCreate) SetDeletedAt(t time.Time) *StoragePolicyCreate { + spc.mutation.SetDeletedAt(t) + return spc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableDeletedAt(t *time.Time) *StoragePolicyCreate { + if t != nil { + spc.SetDeletedAt(*t) + } + return spc +} + +// SetName sets the "name" field. +func (spc *StoragePolicyCreate) SetName(s string) *StoragePolicyCreate { + spc.mutation.SetName(s) + return spc +} + +// SetType sets the "type" field. +func (spc *StoragePolicyCreate) SetType(s string) *StoragePolicyCreate { + spc.mutation.SetType(s) + return spc +} + +// SetServer sets the "server" field. +func (spc *StoragePolicyCreate) SetServer(s string) *StoragePolicyCreate { + spc.mutation.SetServer(s) + return spc +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableServer(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetServer(*s) + } + return spc +} + +// SetBucketName sets the "bucket_name" field. +func (spc *StoragePolicyCreate) SetBucketName(s string) *StoragePolicyCreate { + spc.mutation.SetBucketName(s) + return spc +} + +// SetNillableBucketName sets the "bucket_name" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableBucketName(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetBucketName(*s) + } + return spc +} + +// SetIsPrivate sets the "is_private" field. +func (spc *StoragePolicyCreate) SetIsPrivate(b bool) *StoragePolicyCreate { + spc.mutation.SetIsPrivate(b) + return spc +} + +// SetNillableIsPrivate sets the "is_private" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableIsPrivate(b *bool) *StoragePolicyCreate { + if b != nil { + spc.SetIsPrivate(*b) + } + return spc +} + +// SetAccessKey sets the "access_key" field. +func (spc *StoragePolicyCreate) SetAccessKey(s string) *StoragePolicyCreate { + spc.mutation.SetAccessKey(s) + return spc +} + +// SetNillableAccessKey sets the "access_key" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableAccessKey(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetAccessKey(*s) + } + return spc +} + +// SetSecretKey sets the "secret_key" field. +func (spc *StoragePolicyCreate) SetSecretKey(s string) *StoragePolicyCreate { + spc.mutation.SetSecretKey(s) + return spc +} + +// SetNillableSecretKey sets the "secret_key" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableSecretKey(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetSecretKey(*s) + } + return spc +} + +// SetMaxSize sets the "max_size" field. +func (spc *StoragePolicyCreate) SetMaxSize(i int64) *StoragePolicyCreate { + spc.mutation.SetMaxSize(i) + return spc +} + +// SetNillableMaxSize sets the "max_size" field if the given value is not nil. 
+func (spc *StoragePolicyCreate) SetNillableMaxSize(i *int64) *StoragePolicyCreate { + if i != nil { + spc.SetMaxSize(*i) + } + return spc +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (spc *StoragePolicyCreate) SetDirNameRule(s string) *StoragePolicyCreate { + spc.mutation.SetDirNameRule(s) + return spc +} + +// SetNillableDirNameRule sets the "dir_name_rule" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableDirNameRule(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetDirNameRule(*s) + } + return spc +} + +// SetFileNameRule sets the "file_name_rule" field. +func (spc *StoragePolicyCreate) SetFileNameRule(s string) *StoragePolicyCreate { + spc.mutation.SetFileNameRule(s) + return spc +} + +// SetNillableFileNameRule sets the "file_name_rule" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableFileNameRule(s *string) *StoragePolicyCreate { + if s != nil { + spc.SetFileNameRule(*s) + } + return spc +} + +// SetSettings sets the "settings" field. +func (spc *StoragePolicyCreate) SetSettings(ts *types.PolicySetting) *StoragePolicyCreate { + spc.mutation.SetSettings(ts) + return spc +} + +// SetNodeID sets the "node_id" field. +func (spc *StoragePolicyCreate) SetNodeID(i int) *StoragePolicyCreate { + spc.mutation.SetNodeID(i) + return spc +} + +// SetNillableNodeID sets the "node_id" field if the given value is not nil. +func (spc *StoragePolicyCreate) SetNillableNodeID(i *int) *StoragePolicyCreate { + if i != nil { + spc.SetNodeID(*i) + } + return spc +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (spc *StoragePolicyCreate) AddUserIDs(ids ...int) *StoragePolicyCreate { + spc.mutation.AddUserIDs(ids...) + return spc +} + +// AddUsers adds the "users" edges to the User entity. +func (spc *StoragePolicyCreate) AddUsers(u ...*User) *StoragePolicyCreate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return spc.AddUserIDs(ids...) +} + +// AddGroupIDs adds the "groups" edge to the Group entity by IDs. +func (spc *StoragePolicyCreate) AddGroupIDs(ids ...int) *StoragePolicyCreate { + spc.mutation.AddGroupIDs(ids...) + return spc +} + +// AddGroups adds the "groups" edges to the Group entity. +func (spc *StoragePolicyCreate) AddGroups(g ...*Group) *StoragePolicyCreate { + ids := make([]int, len(g)) + for i := range g { + ids[i] = g[i].ID + } + return spc.AddGroupIDs(ids...) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (spc *StoragePolicyCreate) AddFileIDs(ids ...int) *StoragePolicyCreate { + spc.mutation.AddFileIDs(ids...) + return spc +} + +// AddFiles adds the "files" edges to the File entity. +func (spc *StoragePolicyCreate) AddFiles(f ...*File) *StoragePolicyCreate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return spc.AddFileIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (spc *StoragePolicyCreate) AddEntityIDs(ids ...int) *StoragePolicyCreate { + spc.mutation.AddEntityIDs(ids...) + return spc +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (spc *StoragePolicyCreate) AddEntities(e ...*Entity) *StoragePolicyCreate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return spc.AddEntityIDs(ids...) +} + +// SetNode sets the "node" edge to the Node entity. 
+func (spc *StoragePolicyCreate) SetNode(n *Node) *StoragePolicyCreate { + return spc.SetNodeID(n.ID) +} + +// Mutation returns the StoragePolicyMutation object of the builder. +func (spc *StoragePolicyCreate) Mutation() *StoragePolicyMutation { + return spc.mutation +} + +// Save creates the StoragePolicy in the database. +func (spc *StoragePolicyCreate) Save(ctx context.Context) (*StoragePolicy, error) { + if err := spc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, spc.sqlSave, spc.mutation, spc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (spc *StoragePolicyCreate) SaveX(ctx context.Context) *StoragePolicy { + v, err := spc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (spc *StoragePolicyCreate) Exec(ctx context.Context) error { + _, err := spc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (spc *StoragePolicyCreate) ExecX(ctx context.Context) { + if err := spc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (spc *StoragePolicyCreate) defaults() error { + if _, ok := spc.mutation.CreatedAt(); !ok { + if storagepolicy.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized storagepolicy.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := storagepolicy.DefaultCreatedAt() + spc.mutation.SetCreatedAt(v) + } + if _, ok := spc.mutation.UpdatedAt(); !ok { + if storagepolicy.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized storagepolicy.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := storagepolicy.DefaultUpdatedAt() + spc.mutation.SetUpdatedAt(v) + } + if _, ok := spc.mutation.Settings(); !ok { + v := storagepolicy.DefaultSettings + spc.mutation.SetSettings(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (spc *StoragePolicyCreate) check() error { + if _, ok := spc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "StoragePolicy.created_at"`)} + } + if _, ok := spc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "StoragePolicy.updated_at"`)} + } + if _, ok := spc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "StoragePolicy.name"`)} + } + if _, ok := spc.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "StoragePolicy.type"`)} + } + return nil +} + +func (spc *StoragePolicyCreate) sqlSave(ctx context.Context) (*StoragePolicy, error) { + if err := spc.check(); err != nil { + return nil, err + } + _node, _spec := spc.createSpec() + if err := sqlgraph.CreateNode(ctx, spc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + spc.mutation.id = &_node.ID + spc.mutation.done = true + return _node, nil +} + +func (spc *StoragePolicyCreate) createSpec() (*StoragePolicy, *sqlgraph.CreateSpec) { + var ( + _node = &StoragePolicy{config: spc.config} + _spec = sqlgraph.NewCreateSpec(storagepolicy.Table, sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt)) + ) + + if id, ok := spc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = spc.conflict + if value, ok := spc.mutation.CreatedAt(); ok { + _spec.SetField(storagepolicy.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := spc.mutation.UpdatedAt(); ok { + _spec.SetField(storagepolicy.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := spc.mutation.DeletedAt(); ok { + _spec.SetField(storagepolicy.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := spc.mutation.Name(); ok { + _spec.SetField(storagepolicy.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := spc.mutation.GetType(); ok { + _spec.SetField(storagepolicy.FieldType, field.TypeString, value) + _node.Type = value + } + if value, ok := spc.mutation.Server(); ok { + _spec.SetField(storagepolicy.FieldServer, field.TypeString, value) + _node.Server = value + } + if value, ok := spc.mutation.BucketName(); ok { + _spec.SetField(storagepolicy.FieldBucketName, field.TypeString, value) + _node.BucketName = value + } + if value, ok := spc.mutation.IsPrivate(); ok { + _spec.SetField(storagepolicy.FieldIsPrivate, field.TypeBool, value) + _node.IsPrivate = value + } + if value, ok := spc.mutation.AccessKey(); ok { + _spec.SetField(storagepolicy.FieldAccessKey, field.TypeString, value) + _node.AccessKey = value + } + if value, ok := spc.mutation.SecretKey(); ok { + _spec.SetField(storagepolicy.FieldSecretKey, field.TypeString, value) + _node.SecretKey = value + } + if value, ok := spc.mutation.MaxSize(); ok { + _spec.SetField(storagepolicy.FieldMaxSize, field.TypeInt64, value) + _node.MaxSize = value + } + if value, ok := spc.mutation.DirNameRule(); ok { + _spec.SetField(storagepolicy.FieldDirNameRule, field.TypeString, value) + _node.DirNameRule = value + } + if value, ok := spc.mutation.FileNameRule(); ok { + _spec.SetField(storagepolicy.FieldFileNameRule, field.TypeString, value) + _node.FileNameRule = value + 
} + if value, ok := spc.mutation.Settings(); ok { + _spec.SetField(storagepolicy.FieldSettings, field.TypeJSON, value) + _node.Settings = value + } + if nodes := spc.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := spc.mutation.GroupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := spc.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := spc.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := spc.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: storagepolicy.NodeTable, + Columns: []string{storagepolicy.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.NodeID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.StoragePolicy.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.StoragePolicyUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (spc *StoragePolicyCreate) OnConflict(opts ...sql.ConflictOption) *StoragePolicyUpsertOne { + spc.conflict = opts + return &StoragePolicyUpsertOne{ + create: spc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict(sql.ConflictColumns(columns...)). 
+// Exec(ctx) +func (spc *StoragePolicyCreate) OnConflictColumns(columns ...string) *StoragePolicyUpsertOne { + spc.conflict = append(spc.conflict, sql.ConflictColumns(columns...)) + return &StoragePolicyUpsertOne{ + create: spc, + } +} + +type ( + // StoragePolicyUpsertOne is the builder for "upsert"-ing + // one StoragePolicy node. + StoragePolicyUpsertOne struct { + create *StoragePolicyCreate + } + + // StoragePolicyUpsert is the "OnConflict" setter. + StoragePolicyUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *StoragePolicyUpsert) SetUpdatedAt(v time.Time) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateUpdatedAt() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *StoragePolicyUpsert) SetDeletedAt(v time.Time) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateDeletedAt() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *StoragePolicyUpsert) ClearDeletedAt() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldDeletedAt) + return u +} + +// SetName sets the "name" field. +func (u *StoragePolicyUpsert) SetName(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateName() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldName) + return u +} + +// SetType sets the "type" field. +func (u *StoragePolicyUpsert) SetType(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldType, v) + return u +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateType() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldType) + return u +} + +// SetServer sets the "server" field. +func (u *StoragePolicyUpsert) SetServer(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldServer, v) + return u +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateServer() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldServer) + return u +} + +// ClearServer clears the value of the "server" field. +func (u *StoragePolicyUpsert) ClearServer() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldServer) + return u +} + +// SetBucketName sets the "bucket_name" field. +func (u *StoragePolicyUpsert) SetBucketName(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldBucketName, v) + return u +} + +// UpdateBucketName sets the "bucket_name" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateBucketName() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldBucketName) + return u +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (u *StoragePolicyUpsert) ClearBucketName() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldBucketName) + return u +} + +// SetIsPrivate sets the "is_private" field. 
+func (u *StoragePolicyUpsert) SetIsPrivate(v bool) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldIsPrivate, v) + return u +} + +// UpdateIsPrivate sets the "is_private" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateIsPrivate() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldIsPrivate) + return u +} + +// ClearIsPrivate clears the value of the "is_private" field. +func (u *StoragePolicyUpsert) ClearIsPrivate() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldIsPrivate) + return u +} + +// SetAccessKey sets the "access_key" field. +func (u *StoragePolicyUpsert) SetAccessKey(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldAccessKey, v) + return u +} + +// UpdateAccessKey sets the "access_key" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateAccessKey() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldAccessKey) + return u +} + +// ClearAccessKey clears the value of the "access_key" field. +func (u *StoragePolicyUpsert) ClearAccessKey() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldAccessKey) + return u +} + +// SetSecretKey sets the "secret_key" field. +func (u *StoragePolicyUpsert) SetSecretKey(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldSecretKey, v) + return u +} + +// UpdateSecretKey sets the "secret_key" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateSecretKey() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldSecretKey) + return u +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (u *StoragePolicyUpsert) ClearSecretKey() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldSecretKey) + return u +} + +// SetMaxSize sets the "max_size" field. +func (u *StoragePolicyUpsert) SetMaxSize(v int64) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldMaxSize, v) + return u +} + +// UpdateMaxSize sets the "max_size" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateMaxSize() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldMaxSize) + return u +} + +// AddMaxSize adds v to the "max_size" field. +func (u *StoragePolicyUpsert) AddMaxSize(v int64) *StoragePolicyUpsert { + u.Add(storagepolicy.FieldMaxSize, v) + return u +} + +// ClearMaxSize clears the value of the "max_size" field. +func (u *StoragePolicyUpsert) ClearMaxSize() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldMaxSize) + return u +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (u *StoragePolicyUpsert) SetDirNameRule(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldDirNameRule, v) + return u +} + +// UpdateDirNameRule sets the "dir_name_rule" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateDirNameRule() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldDirNameRule) + return u +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (u *StoragePolicyUpsert) ClearDirNameRule() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldDirNameRule) + return u +} + +// SetFileNameRule sets the "file_name_rule" field. +func (u *StoragePolicyUpsert) SetFileNameRule(v string) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldFileNameRule, v) + return u +} + +// UpdateFileNameRule sets the "file_name_rule" field to the value that was provided on create. 
+func (u *StoragePolicyUpsert) UpdateFileNameRule() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldFileNameRule) + return u +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (u *StoragePolicyUpsert) ClearFileNameRule() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldFileNameRule) + return u +} + +// SetSettings sets the "settings" field. +func (u *StoragePolicyUpsert) SetSettings(v *types.PolicySetting) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldSettings, v) + return u +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateSettings() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldSettings) + return u +} + +// ClearSettings clears the value of the "settings" field. +func (u *StoragePolicyUpsert) ClearSettings() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldSettings) + return u +} + +// SetNodeID sets the "node_id" field. +func (u *StoragePolicyUpsert) SetNodeID(v int) *StoragePolicyUpsert { + u.Set(storagepolicy.FieldNodeID, v) + return u +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *StoragePolicyUpsert) UpdateNodeID() *StoragePolicyUpsert { + u.SetExcluded(storagepolicy.FieldNodeID) + return u +} + +// ClearNodeID clears the value of the "node_id" field. +func (u *StoragePolicyUpsert) ClearNodeID() *StoragePolicyUpsert { + u.SetNull(storagepolicy.FieldNodeID) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *StoragePolicyUpsertOne) UpdateNewValues() *StoragePolicyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(storagepolicy.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *StoragePolicyUpsertOne) Ignore() *StoragePolicyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *StoragePolicyUpsertOne) DoNothing() *StoragePolicyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the StoragePolicyCreate.OnConflict +// documentation for more info. +func (u *StoragePolicyUpsertOne) Update(set func(*StoragePolicyUpsert)) *StoragePolicyUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&StoragePolicyUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *StoragePolicyUpsertOne) SetUpdatedAt(v time.Time) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *StoragePolicyUpsertOne) UpdateUpdatedAt() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *StoragePolicyUpsertOne) SetDeletedAt(v time.Time) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateDeletedAt() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *StoragePolicyUpsertOne) ClearDeletedAt() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *StoragePolicyUpsertOne) SetName(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateName() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateName() + }) +} + +// SetType sets the "type" field. +func (u *StoragePolicyUpsertOne) SetType(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateType() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateType() + }) +} + +// SetServer sets the "server" field. +func (u *StoragePolicyUpsertOne) SetServer(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetServer(v) + }) +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateServer() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateServer() + }) +} + +// ClearServer clears the value of the "server" field. +func (u *StoragePolicyUpsertOne) ClearServer() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearServer() + }) +} + +// SetBucketName sets the "bucket_name" field. +func (u *StoragePolicyUpsertOne) SetBucketName(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetBucketName(v) + }) +} + +// UpdateBucketName sets the "bucket_name" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateBucketName() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateBucketName() + }) +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (u *StoragePolicyUpsertOne) ClearBucketName() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearBucketName() + }) +} + +// SetIsPrivate sets the "is_private" field. +func (u *StoragePolicyUpsertOne) SetIsPrivate(v bool) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetIsPrivate(v) + }) +} + +// UpdateIsPrivate sets the "is_private" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateIsPrivate() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateIsPrivate() + }) +} + +// ClearIsPrivate clears the value of the "is_private" field. 
+func (u *StoragePolicyUpsertOne) ClearIsPrivate() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearIsPrivate() + }) +} + +// SetAccessKey sets the "access_key" field. +func (u *StoragePolicyUpsertOne) SetAccessKey(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetAccessKey(v) + }) +} + +// UpdateAccessKey sets the "access_key" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateAccessKey() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateAccessKey() + }) +} + +// ClearAccessKey clears the value of the "access_key" field. +func (u *StoragePolicyUpsertOne) ClearAccessKey() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearAccessKey() + }) +} + +// SetSecretKey sets the "secret_key" field. +func (u *StoragePolicyUpsertOne) SetSecretKey(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetSecretKey(v) + }) +} + +// UpdateSecretKey sets the "secret_key" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateSecretKey() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateSecretKey() + }) +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (u *StoragePolicyUpsertOne) ClearSecretKey() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearSecretKey() + }) +} + +// SetMaxSize sets the "max_size" field. +func (u *StoragePolicyUpsertOne) SetMaxSize(v int64) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetMaxSize(v) + }) +} + +// AddMaxSize adds v to the "max_size" field. +func (u *StoragePolicyUpsertOne) AddMaxSize(v int64) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.AddMaxSize(v) + }) +} + +// UpdateMaxSize sets the "max_size" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateMaxSize() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateMaxSize() + }) +} + +// ClearMaxSize clears the value of the "max_size" field. +func (u *StoragePolicyUpsertOne) ClearMaxSize() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearMaxSize() + }) +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (u *StoragePolicyUpsertOne) SetDirNameRule(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetDirNameRule(v) + }) +} + +// UpdateDirNameRule sets the "dir_name_rule" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateDirNameRule() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateDirNameRule() + }) +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (u *StoragePolicyUpsertOne) ClearDirNameRule() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearDirNameRule() + }) +} + +// SetFileNameRule sets the "file_name_rule" field. +func (u *StoragePolicyUpsertOne) SetFileNameRule(v string) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetFileNameRule(v) + }) +} + +// UpdateFileNameRule sets the "file_name_rule" field to the value that was provided on create. 
+func (u *StoragePolicyUpsertOne) UpdateFileNameRule() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateFileNameRule() + }) +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (u *StoragePolicyUpsertOne) ClearFileNameRule() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearFileNameRule() + }) +} + +// SetSettings sets the "settings" field. +func (u *StoragePolicyUpsertOne) SetSettings(v *types.PolicySetting) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateSettings() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *StoragePolicyUpsertOne) ClearSettings() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearSettings() + }) +} + +// SetNodeID sets the "node_id" field. +func (u *StoragePolicyUpsertOne) SetNodeID(v int) *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetNodeID(v) + }) +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *StoragePolicyUpsertOne) UpdateNodeID() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateNodeID() + }) +} + +// ClearNodeID clears the value of the "node_id" field. +func (u *StoragePolicyUpsertOne) ClearNodeID() *StoragePolicyUpsertOne { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearNodeID() + }) +} + +// Exec executes the query. +func (u *StoragePolicyUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for StoragePolicyCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *StoragePolicyUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *StoragePolicyUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *StoragePolicyUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *StoragePolicyCreate) SetRawID(t int) *StoragePolicyCreate { + m.mutation.SetRawID(t) + return m +} + +// StoragePolicyCreateBulk is the builder for creating many StoragePolicy entities in bulk. +type StoragePolicyCreateBulk struct { + config + err error + builders []*StoragePolicyCreate + conflict []sql.ConflictOption +} + +// Save creates the StoragePolicy entities in the database. 
+func (spcb *StoragePolicyCreateBulk) Save(ctx context.Context) ([]*StoragePolicy, error) { + if spcb.err != nil { + return nil, spcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(spcb.builders)) + nodes := make([]*StoragePolicy, len(spcb.builders)) + mutators := make([]Mutator, len(spcb.builders)) + for i := range spcb.builders { + func(i int, root context.Context) { + builder := spcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*StoragePolicyMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, spcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = spcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, spcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, spcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (spcb *StoragePolicyCreateBulk) SaveX(ctx context.Context) []*StoragePolicy { + v, err := spcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (spcb *StoragePolicyCreateBulk) Exec(ctx context.Context) error { + _, err := spcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (spcb *StoragePolicyCreateBulk) ExecX(ctx context.Context) { + if err := spcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.StoragePolicy.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.StoragePolicyUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (spcb *StoragePolicyCreateBulk) OnConflict(opts ...sql.ConflictOption) *StoragePolicyUpsertBulk { + spcb.conflict = opts + return &StoragePolicyUpsertBulk{ + create: spcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (spcb *StoragePolicyCreateBulk) OnConflictColumns(columns ...string) *StoragePolicyUpsertBulk { + spcb.conflict = append(spcb.conflict, sql.ConflictColumns(columns...)) + return &StoragePolicyUpsertBulk{ + create: spcb, + } +} + +// StoragePolicyUpsertBulk is the builder for "upsert"-ing +// a bulk of StoragePolicy nodes. 
+type StoragePolicyUpsertBulk struct { + create *StoragePolicyCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *StoragePolicyUpsertBulk) UpdateNewValues() *StoragePolicyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(storagepolicy.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.StoragePolicy.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *StoragePolicyUpsertBulk) Ignore() *StoragePolicyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *StoragePolicyUpsertBulk) DoNothing() *StoragePolicyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the StoragePolicyCreateBulk.OnConflict +// documentation for more info. +func (u *StoragePolicyUpsertBulk) Update(set func(*StoragePolicyUpsert)) *StoragePolicyUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&StoragePolicyUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *StoragePolicyUpsertBulk) SetUpdatedAt(v time.Time) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateUpdatedAt() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *StoragePolicyUpsertBulk) SetDeletedAt(v time.Time) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateDeletedAt() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *StoragePolicyUpsertBulk) ClearDeletedAt() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearDeletedAt() + }) +} + +// SetName sets the "name" field. +func (u *StoragePolicyUpsertBulk) SetName(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateName() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateName() + }) +} + +// SetType sets the "type" field. 
+func (u *StoragePolicyUpsertBulk) SetType(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateType() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateType() + }) +} + +// SetServer sets the "server" field. +func (u *StoragePolicyUpsertBulk) SetServer(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetServer(v) + }) +} + +// UpdateServer sets the "server" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateServer() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateServer() + }) +} + +// ClearServer clears the value of the "server" field. +func (u *StoragePolicyUpsertBulk) ClearServer() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearServer() + }) +} + +// SetBucketName sets the "bucket_name" field. +func (u *StoragePolicyUpsertBulk) SetBucketName(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetBucketName(v) + }) +} + +// UpdateBucketName sets the "bucket_name" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateBucketName() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateBucketName() + }) +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (u *StoragePolicyUpsertBulk) ClearBucketName() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearBucketName() + }) +} + +// SetIsPrivate sets the "is_private" field. +func (u *StoragePolicyUpsertBulk) SetIsPrivate(v bool) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetIsPrivate(v) + }) +} + +// UpdateIsPrivate sets the "is_private" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateIsPrivate() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateIsPrivate() + }) +} + +// ClearIsPrivate clears the value of the "is_private" field. +func (u *StoragePolicyUpsertBulk) ClearIsPrivate() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearIsPrivate() + }) +} + +// SetAccessKey sets the "access_key" field. +func (u *StoragePolicyUpsertBulk) SetAccessKey(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetAccessKey(v) + }) +} + +// UpdateAccessKey sets the "access_key" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateAccessKey() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateAccessKey() + }) +} + +// ClearAccessKey clears the value of the "access_key" field. +func (u *StoragePolicyUpsertBulk) ClearAccessKey() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearAccessKey() + }) +} + +// SetSecretKey sets the "secret_key" field. +func (u *StoragePolicyUpsertBulk) SetSecretKey(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetSecretKey(v) + }) +} + +// UpdateSecretKey sets the "secret_key" field to the value that was provided on create. 
+func (u *StoragePolicyUpsertBulk) UpdateSecretKey() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateSecretKey() + }) +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (u *StoragePolicyUpsertBulk) ClearSecretKey() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearSecretKey() + }) +} + +// SetMaxSize sets the "max_size" field. +func (u *StoragePolicyUpsertBulk) SetMaxSize(v int64) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetMaxSize(v) + }) +} + +// AddMaxSize adds v to the "max_size" field. +func (u *StoragePolicyUpsertBulk) AddMaxSize(v int64) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.AddMaxSize(v) + }) +} + +// UpdateMaxSize sets the "max_size" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateMaxSize() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateMaxSize() + }) +} + +// ClearMaxSize clears the value of the "max_size" field. +func (u *StoragePolicyUpsertBulk) ClearMaxSize() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearMaxSize() + }) +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (u *StoragePolicyUpsertBulk) SetDirNameRule(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetDirNameRule(v) + }) +} + +// UpdateDirNameRule sets the "dir_name_rule" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateDirNameRule() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateDirNameRule() + }) +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (u *StoragePolicyUpsertBulk) ClearDirNameRule() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearDirNameRule() + }) +} + +// SetFileNameRule sets the "file_name_rule" field. +func (u *StoragePolicyUpsertBulk) SetFileNameRule(v string) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetFileNameRule(v) + }) +} + +// UpdateFileNameRule sets the "file_name_rule" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateFileNameRule() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateFileNameRule() + }) +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (u *StoragePolicyUpsertBulk) ClearFileNameRule() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearFileNameRule() + }) +} + +// SetSettings sets the "settings" field. +func (u *StoragePolicyUpsertBulk) SetSettings(v *types.PolicySetting) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateSettings() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *StoragePolicyUpsertBulk) ClearSettings() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearSettings() + }) +} + +// SetNodeID sets the "node_id" field. 
+func (u *StoragePolicyUpsertBulk) SetNodeID(v int) *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.SetNodeID(v) + }) +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *StoragePolicyUpsertBulk) UpdateNodeID() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.UpdateNodeID() + }) +} + +// ClearNodeID clears the value of the "node_id" field. +func (u *StoragePolicyUpsertBulk) ClearNodeID() *StoragePolicyUpsertBulk { + return u.Update(func(s *StoragePolicyUpsert) { + s.ClearNodeID() + }) +} + +// Exec executes the query. +func (u *StoragePolicyUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the StoragePolicyCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for StoragePolicyCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *StoragePolicyUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/storagepolicy_delete.go b/ent/storagepolicy_delete.go new file mode 100644 index 00000000..4e64a360 --- /dev/null +++ b/ent/storagepolicy_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" +) + +// StoragePolicyDelete is the builder for deleting a StoragePolicy entity. +type StoragePolicyDelete struct { + config + hooks []Hook + mutation *StoragePolicyMutation +} + +// Where appends a list predicates to the StoragePolicyDelete builder. +func (spd *StoragePolicyDelete) Where(ps ...predicate.StoragePolicy) *StoragePolicyDelete { + spd.mutation.Where(ps...) + return spd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (spd *StoragePolicyDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, spd.sqlExec, spd.mutation, spd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (spd *StoragePolicyDelete) ExecX(ctx context.Context) int { + n, err := spd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (spd *StoragePolicyDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(storagepolicy.Table, sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt)) + if ps := spd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, spd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + spd.mutation.done = true + return affected, err +} + +// StoragePolicyDeleteOne is the builder for deleting a single StoragePolicy entity. +type StoragePolicyDeleteOne struct { + spd *StoragePolicyDelete +} + +// Where appends a list predicates to the StoragePolicyDelete builder. +func (spdo *StoragePolicyDeleteOne) Where(ps ...predicate.StoragePolicy) *StoragePolicyDeleteOne { + spdo.spd.mutation.Where(ps...) + return spdo +} + +// Exec executes the deletion query. 
+func (spdo *StoragePolicyDeleteOne) Exec(ctx context.Context) error { + n, err := spdo.spd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{storagepolicy.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (spdo *StoragePolicyDeleteOne) ExecX(ctx context.Context) { + if err := spdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/storagepolicy_query.go b/ent/storagepolicy_query.go new file mode 100644 index 00000000..a4570e1d --- /dev/null +++ b/ent/storagepolicy_query.go @@ -0,0 +1,903 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// StoragePolicyQuery is the builder for querying StoragePolicy entities. +type StoragePolicyQuery struct { + config + ctx *QueryContext + order []storagepolicy.OrderOption + inters []Interceptor + predicates []predicate.StoragePolicy + withUsers *UserQuery + withGroups *GroupQuery + withFiles *FileQuery + withEntities *EntityQuery + withNode *NodeQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the StoragePolicyQuery builder. +func (spq *StoragePolicyQuery) Where(ps ...predicate.StoragePolicy) *StoragePolicyQuery { + spq.predicates = append(spq.predicates, ps...) + return spq +} + +// Limit the number of records to be returned by this query. +func (spq *StoragePolicyQuery) Limit(limit int) *StoragePolicyQuery { + spq.ctx.Limit = &limit + return spq +} + +// Offset to start from. +func (spq *StoragePolicyQuery) Offset(offset int) *StoragePolicyQuery { + spq.ctx.Offset = &offset + return spq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (spq *StoragePolicyQuery) Unique(unique bool) *StoragePolicyQuery { + spq.ctx.Unique = &unique + return spq +} + +// Order specifies how the records should be ordered. +func (spq *StoragePolicyQuery) Order(o ...storagepolicy.OrderOption) *StoragePolicyQuery { + spq.order = append(spq.order, o...) + return spq +} + +// QueryUsers chains the current query on the "users" edge. +func (spq *StoragePolicyQuery) QueryUsers() *UserQuery { + query := (&UserClient{config: spq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := spq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.UsersTable, storagepolicy.UsersColumn), + ) + fromU = sqlgraph.SetNeighbors(spq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryGroups chains the current query on the "groups" edge. 
+func (spq *StoragePolicyQuery) QueryGroups() *GroupQuery { + query := (&GroupClient{config: spq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := spq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, selector), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.GroupsTable, storagepolicy.GroupsColumn), + ) + fromU = sqlgraph.SetNeighbors(spq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryFiles chains the current query on the "files" edge. +func (spq *StoragePolicyQuery) QueryFiles() *FileQuery { + query := (&FileClient{config: spq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := spq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.FilesTable, storagepolicy.FilesColumn), + ) + fromU = sqlgraph.SetNeighbors(spq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryEntities chains the current query on the "entities" edge. +func (spq *StoragePolicyQuery) QueryEntities() *EntityQuery { + query := (&EntityClient{config: spq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := spq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, selector), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, storagepolicy.EntitiesTable, storagepolicy.EntitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(spq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryNode chains the current query on the "node" edge. +func (spq *StoragePolicyQuery) QueryNode() *NodeQuery { + query := (&NodeClient{config: spq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := spq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(storagepolicy.Table, storagepolicy.FieldID, selector), + sqlgraph.To(node.Table, node.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, storagepolicy.NodeTable, storagepolicy.NodeColumn), + ) + fromU = sqlgraph.SetNeighbors(spq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first StoragePolicy entity from the query. +// Returns a *NotFoundError when no StoragePolicy was found. +func (spq *StoragePolicyQuery) First(ctx context.Context) (*StoragePolicy, error) { + nodes, err := spq.Limit(1).All(setContextOp(ctx, spq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{storagepolicy.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. 
+func (spq *StoragePolicyQuery) FirstX(ctx context.Context) *StoragePolicy { + node, err := spq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first StoragePolicy ID from the query. +// Returns a *NotFoundError when no StoragePolicy ID was found. +func (spq *StoragePolicyQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = spq.Limit(1).IDs(setContextOp(ctx, spq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{storagepolicy.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (spq *StoragePolicyQuery) FirstIDX(ctx context.Context) int { + id, err := spq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single StoragePolicy entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one StoragePolicy entity is found. +// Returns a *NotFoundError when no StoragePolicy entities are found. +func (spq *StoragePolicyQuery) Only(ctx context.Context) (*StoragePolicy, error) { + nodes, err := spq.Limit(2).All(setContextOp(ctx, spq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{storagepolicy.Label} + default: + return nil, &NotSingularError{storagepolicy.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (spq *StoragePolicyQuery) OnlyX(ctx context.Context) *StoragePolicy { + node, err := spq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only StoragePolicy ID in the query. +// Returns a *NotSingularError when more than one StoragePolicy ID is found. +// Returns a *NotFoundError when no entities are found. +func (spq *StoragePolicyQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = spq.Limit(2).IDs(setContextOp(ctx, spq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{storagepolicy.Label} + default: + err = &NotSingularError{storagepolicy.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (spq *StoragePolicyQuery) OnlyIDX(ctx context.Context) int { + id, err := spq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of StoragePolicies. +func (spq *StoragePolicyQuery) All(ctx context.Context) ([]*StoragePolicy, error) { + ctx = setContextOp(ctx, spq.ctx, "All") + if err := spq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*StoragePolicy, *StoragePolicyQuery]() + return withInterceptors[[]*StoragePolicy](ctx, spq, qr, spq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (spq *StoragePolicyQuery) AllX(ctx context.Context) []*StoragePolicy { + nodes, err := spq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of StoragePolicy IDs. +func (spq *StoragePolicyQuery) IDs(ctx context.Context) (ids []int, err error) { + if spq.ctx.Unique == nil && spq.path != nil { + spq.Unique(true) + } + ctx = setContextOp(ctx, spq.ctx, "IDs") + if err = spq.Select(storagepolicy.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. 
+func (spq *StoragePolicyQuery) IDsX(ctx context.Context) []int { + ids, err := spq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (spq *StoragePolicyQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, spq.ctx, "Count") + if err := spq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, spq, querierCount[*StoragePolicyQuery](), spq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (spq *StoragePolicyQuery) CountX(ctx context.Context) int { + count, err := spq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (spq *StoragePolicyQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, spq.ctx, "Exist") + switch _, err := spq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (spq *StoragePolicyQuery) ExistX(ctx context.Context) bool { + exist, err := spq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the StoragePolicyQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (spq *StoragePolicyQuery) Clone() *StoragePolicyQuery { + if spq == nil { + return nil + } + return &StoragePolicyQuery{ + config: spq.config, + ctx: spq.ctx.Clone(), + order: append([]storagepolicy.OrderOption{}, spq.order...), + inters: append([]Interceptor{}, spq.inters...), + predicates: append([]predicate.StoragePolicy{}, spq.predicates...), + withUsers: spq.withUsers.Clone(), + withGroups: spq.withGroups.Clone(), + withFiles: spq.withFiles.Clone(), + withEntities: spq.withEntities.Clone(), + withNode: spq.withNode.Clone(), + // clone intermediate query. + sql: spq.sql.Clone(), + path: spq.path, + } +} + +// WithUsers tells the query-builder to eager-load the nodes that are connected to +// the "users" edge. The optional arguments are used to configure the query builder of the edge. +func (spq *StoragePolicyQuery) WithUsers(opts ...func(*UserQuery)) *StoragePolicyQuery { + query := (&UserClient{config: spq.config}).Query() + for _, opt := range opts { + opt(query) + } + spq.withUsers = query + return spq +} + +// WithGroups tells the query-builder to eager-load the nodes that are connected to +// the "groups" edge. The optional arguments are used to configure the query builder of the edge. +func (spq *StoragePolicyQuery) WithGroups(opts ...func(*GroupQuery)) *StoragePolicyQuery { + query := (&GroupClient{config: spq.config}).Query() + for _, opt := range opts { + opt(query) + } + spq.withGroups = query + return spq +} + +// WithFiles tells the query-builder to eager-load the nodes that are connected to +// the "files" edge. The optional arguments are used to configure the query builder of the edge. +func (spq *StoragePolicyQuery) WithFiles(opts ...func(*FileQuery)) *StoragePolicyQuery { + query := (&FileClient{config: spq.config}).Query() + for _, opt := range opts { + opt(query) + } + spq.withFiles = query + return spq +} + +// WithEntities tells the query-builder to eager-load the nodes that are connected to +// the "entities" edge. The optional arguments are used to configure the query builder of the edge. 
+func (spq *StoragePolicyQuery) WithEntities(opts ...func(*EntityQuery)) *StoragePolicyQuery { + query := (&EntityClient{config: spq.config}).Query() + for _, opt := range opts { + opt(query) + } + spq.withEntities = query + return spq +} + +// WithNode tells the query-builder to eager-load the nodes that are connected to +// the "node" edge. The optional arguments are used to configure the query builder of the edge. +func (spq *StoragePolicyQuery) WithNode(opts ...func(*NodeQuery)) *StoragePolicyQuery { + query := (&NodeClient{config: spq.config}).Query() + for _, opt := range opts { + opt(query) + } + spq.withNode = query + return spq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.StoragePolicy.Query(). +// GroupBy(storagepolicy.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (spq *StoragePolicyQuery) GroupBy(field string, fields ...string) *StoragePolicyGroupBy { + spq.ctx.Fields = append([]string{field}, fields...) + grbuild := &StoragePolicyGroupBy{build: spq} + grbuild.flds = &spq.ctx.Fields + grbuild.label = storagepolicy.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.StoragePolicy.Query(). +// Select(storagepolicy.FieldCreatedAt). +// Scan(ctx, &v) +func (spq *StoragePolicyQuery) Select(fields ...string) *StoragePolicySelect { + spq.ctx.Fields = append(spq.ctx.Fields, fields...) + sbuild := &StoragePolicySelect{StoragePolicyQuery: spq} + sbuild.label = storagepolicy.Label + sbuild.flds, sbuild.scan = &spq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a StoragePolicySelect configured with the given aggregations. +func (spq *StoragePolicyQuery) Aggregate(fns ...AggregateFunc) *StoragePolicySelect { + return spq.Select().Aggregate(fns...) 
+} + +func (spq *StoragePolicyQuery) prepareQuery(ctx context.Context) error { + for _, inter := range spq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, spq); err != nil { + return err + } + } + } + for _, f := range spq.ctx.Fields { + if !storagepolicy.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if spq.path != nil { + prev, err := spq.path(ctx) + if err != nil { + return err + } + spq.sql = prev + } + return nil +} + +func (spq *StoragePolicyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*StoragePolicy, error) { + var ( + nodes = []*StoragePolicy{} + _spec = spq.querySpec() + loadedTypes = [5]bool{ + spq.withUsers != nil, + spq.withGroups != nil, + spq.withFiles != nil, + spq.withEntities != nil, + spq.withNode != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*StoragePolicy).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &StoragePolicy{config: spq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, spq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := spq.withUsers; query != nil { + if err := spq.loadUsers(ctx, query, nodes, + func(n *StoragePolicy) { n.Edges.Users = []*User{} }, + func(n *StoragePolicy, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { + return nil, err + } + } + if query := spq.withGroups; query != nil { + if err := spq.loadGroups(ctx, query, nodes, + func(n *StoragePolicy) { n.Edges.Groups = []*Group{} }, + func(n *StoragePolicy, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { + return nil, err + } + } + if query := spq.withFiles; query != nil { + if err := spq.loadFiles(ctx, query, nodes, + func(n *StoragePolicy) { n.Edges.Files = []*File{} }, + func(n *StoragePolicy, e *File) { n.Edges.Files = append(n.Edges.Files, e) }); err != nil { + return nil, err + } + } + if query := spq.withEntities; query != nil { + if err := spq.loadEntities(ctx, query, nodes, + func(n *StoragePolicy) { n.Edges.Entities = []*Entity{} }, + func(n *StoragePolicy, e *Entity) { n.Edges.Entities = append(n.Edges.Entities, e) }); err != nil { + return nil, err + } + } + if query := spq.withNode; query != nil { + if err := spq.loadNode(ctx, query, nodes, nil, + func(n *StoragePolicy, e *Node) { n.Edges.Node = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (spq *StoragePolicyQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*StoragePolicy, init func(*StoragePolicy), assign func(*StoragePolicy, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*StoragePolicy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(storagepolicy.UsersColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.storage_policy_users + if fk == nil { + return fmt.Errorf(`foreign-key "storage_policy_users" is nil 
for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "storage_policy_users" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (spq *StoragePolicyQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*StoragePolicy, init func(*StoragePolicy), assign func(*StoragePolicy, *Group)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*StoragePolicy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(group.FieldStoragePolicyID) + } + query.Where(predicate.Group(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(storagepolicy.GroupsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.StoragePolicyID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "storage_policy_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (spq *StoragePolicyQuery) loadFiles(ctx context.Context, query *FileQuery, nodes []*StoragePolicy, init func(*StoragePolicy), assign func(*StoragePolicy, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*StoragePolicy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(file.FieldStoragePolicyFiles) + } + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(storagepolicy.FilesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.StoragePolicyFiles + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "storage_policy_files" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (spq *StoragePolicyQuery) loadEntities(ctx context.Context, query *EntityQuery, nodes []*StoragePolicy, init func(*StoragePolicy), assign func(*StoragePolicy, *Entity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*StoragePolicy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(entity.FieldStoragePolicyEntities) + } + query.Where(predicate.Entity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(storagepolicy.EntitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.StoragePolicyEntities + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "storage_policy_entities" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (spq *StoragePolicyQuery) loadNode(ctx context.Context, query *NodeQuery, nodes []*StoragePolicy, init func(*StoragePolicy), assign func(*StoragePolicy, *Node)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*StoragePolicy) + for i := range nodes { + fk := nodes[i].NodeID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(node.IDIn(ids...)) + 
neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (spq *StoragePolicyQuery) sqlCount(ctx context.Context) (int, error) { + _spec := spq.querySpec() + _spec.Node.Columns = spq.ctx.Fields + if len(spq.ctx.Fields) > 0 { + _spec.Unique = spq.ctx.Unique != nil && *spq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, spq.driver, _spec) +} + +func (spq *StoragePolicyQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(storagepolicy.Table, storagepolicy.Columns, sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt)) + _spec.From = spq.sql + if unique := spq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if spq.path != nil { + _spec.Unique = true + } + if fields := spq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, storagepolicy.FieldID) + for i := range fields { + if fields[i] != storagepolicy.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if spq.withNode != nil { + _spec.Node.AddColumnOnce(storagepolicy.FieldNodeID) + } + } + if ps := spq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := spq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := spq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := spq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (spq *StoragePolicyQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(spq.driver.Dialect()) + t1 := builder.Table(storagepolicy.Table) + columns := spq.ctx.Fields + if len(columns) == 0 { + columns = storagepolicy.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if spq.sql != nil { + selector = spq.sql + selector.Select(selector.Columns(columns...)...) + } + if spq.ctx.Unique != nil && *spq.ctx.Unique { + selector.Distinct() + } + for _, p := range spq.predicates { + p(selector) + } + for _, p := range spq.order { + p(selector) + } + if offset := spq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := spq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// StoragePolicyGroupBy is the group-by builder for StoragePolicy entities. +type StoragePolicyGroupBy struct { + selector + build *StoragePolicyQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (spgb *StoragePolicyGroupBy) Aggregate(fns ...AggregateFunc) *StoragePolicyGroupBy { + spgb.fns = append(spgb.fns, fns...) + return spgb +} + +// Scan applies the selector query and scans the result into the given value. 
+func (spgb *StoragePolicyGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, spgb.build.ctx, "GroupBy") + if err := spgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*StoragePolicyQuery, *StoragePolicyGroupBy](ctx, spgb.build, spgb, spgb.build.inters, v) +} + +func (spgb *StoragePolicyGroupBy) sqlScan(ctx context.Context, root *StoragePolicyQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(spgb.fns)) + for _, fn := range spgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*spgb.flds)+len(spgb.fns)) + for _, f := range *spgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*spgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := spgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// StoragePolicySelect is the builder for selecting fields of StoragePolicy entities. +type StoragePolicySelect struct { + *StoragePolicyQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (sps *StoragePolicySelect) Aggregate(fns ...AggregateFunc) *StoragePolicySelect { + sps.fns = append(sps.fns, fns...) + return sps +} + +// Scan applies the selector query and scans the result into the given value. +func (sps *StoragePolicySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, sps.ctx, "Select") + if err := sps.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*StoragePolicyQuery, *StoragePolicySelect](ctx, sps.StoragePolicyQuery, sps, sps.inters, v) +} + +func (sps *StoragePolicySelect) sqlScan(ctx context.Context, root *StoragePolicyQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(sps.fns)) + for _, fn := range sps.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*sps.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := sps.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/storagepolicy_update.go b/ent/storagepolicy_update.go new file mode 100644 index 00000000..b679d693 --- /dev/null +++ b/ent/storagepolicy_update.go @@ -0,0 +1,1590 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// StoragePolicyUpdate is the builder for updating StoragePolicy entities. 
+type StoragePolicyUpdate struct { + config + hooks []Hook + mutation *StoragePolicyMutation +} + +// Where appends a list predicates to the StoragePolicyUpdate builder. +func (spu *StoragePolicyUpdate) Where(ps ...predicate.StoragePolicy) *StoragePolicyUpdate { + spu.mutation.Where(ps...) + return spu +} + +// SetUpdatedAt sets the "updated_at" field. +func (spu *StoragePolicyUpdate) SetUpdatedAt(t time.Time) *StoragePolicyUpdate { + spu.mutation.SetUpdatedAt(t) + return spu +} + +// SetDeletedAt sets the "deleted_at" field. +func (spu *StoragePolicyUpdate) SetDeletedAt(t time.Time) *StoragePolicyUpdate { + spu.mutation.SetDeletedAt(t) + return spu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableDeletedAt(t *time.Time) *StoragePolicyUpdate { + if t != nil { + spu.SetDeletedAt(*t) + } + return spu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (spu *StoragePolicyUpdate) ClearDeletedAt() *StoragePolicyUpdate { + spu.mutation.ClearDeletedAt() + return spu +} + +// SetName sets the "name" field. +func (spu *StoragePolicyUpdate) SetName(s string) *StoragePolicyUpdate { + spu.mutation.SetName(s) + return spu +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableName(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetName(*s) + } + return spu +} + +// SetType sets the "type" field. +func (spu *StoragePolicyUpdate) SetType(s string) *StoragePolicyUpdate { + spu.mutation.SetType(s) + return spu +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableType(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetType(*s) + } + return spu +} + +// SetServer sets the "server" field. +func (spu *StoragePolicyUpdate) SetServer(s string) *StoragePolicyUpdate { + spu.mutation.SetServer(s) + return spu +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableServer(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetServer(*s) + } + return spu +} + +// ClearServer clears the value of the "server" field. +func (spu *StoragePolicyUpdate) ClearServer() *StoragePolicyUpdate { + spu.mutation.ClearServer() + return spu +} + +// SetBucketName sets the "bucket_name" field. +func (spu *StoragePolicyUpdate) SetBucketName(s string) *StoragePolicyUpdate { + spu.mutation.SetBucketName(s) + return spu +} + +// SetNillableBucketName sets the "bucket_name" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableBucketName(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetBucketName(*s) + } + return spu +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (spu *StoragePolicyUpdate) ClearBucketName() *StoragePolicyUpdate { + spu.mutation.ClearBucketName() + return spu +} + +// SetIsPrivate sets the "is_private" field. +func (spu *StoragePolicyUpdate) SetIsPrivate(b bool) *StoragePolicyUpdate { + spu.mutation.SetIsPrivate(b) + return spu +} + +// SetNillableIsPrivate sets the "is_private" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableIsPrivate(b *bool) *StoragePolicyUpdate { + if b != nil { + spu.SetIsPrivate(*b) + } + return spu +} + +// ClearIsPrivate clears the value of the "is_private" field. 
+func (spu *StoragePolicyUpdate) ClearIsPrivate() *StoragePolicyUpdate { + spu.mutation.ClearIsPrivate() + return spu +} + +// SetAccessKey sets the "access_key" field. +func (spu *StoragePolicyUpdate) SetAccessKey(s string) *StoragePolicyUpdate { + spu.mutation.SetAccessKey(s) + return spu +} + +// SetNillableAccessKey sets the "access_key" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableAccessKey(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetAccessKey(*s) + } + return spu +} + +// ClearAccessKey clears the value of the "access_key" field. +func (spu *StoragePolicyUpdate) ClearAccessKey() *StoragePolicyUpdate { + spu.mutation.ClearAccessKey() + return spu +} + +// SetSecretKey sets the "secret_key" field. +func (spu *StoragePolicyUpdate) SetSecretKey(s string) *StoragePolicyUpdate { + spu.mutation.SetSecretKey(s) + return spu +} + +// SetNillableSecretKey sets the "secret_key" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableSecretKey(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetSecretKey(*s) + } + return spu +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (spu *StoragePolicyUpdate) ClearSecretKey() *StoragePolicyUpdate { + spu.mutation.ClearSecretKey() + return spu +} + +// SetMaxSize sets the "max_size" field. +func (spu *StoragePolicyUpdate) SetMaxSize(i int64) *StoragePolicyUpdate { + spu.mutation.ResetMaxSize() + spu.mutation.SetMaxSize(i) + return spu +} + +// SetNillableMaxSize sets the "max_size" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableMaxSize(i *int64) *StoragePolicyUpdate { + if i != nil { + spu.SetMaxSize(*i) + } + return spu +} + +// AddMaxSize adds i to the "max_size" field. +func (spu *StoragePolicyUpdate) AddMaxSize(i int64) *StoragePolicyUpdate { + spu.mutation.AddMaxSize(i) + return spu +} + +// ClearMaxSize clears the value of the "max_size" field. +func (spu *StoragePolicyUpdate) ClearMaxSize() *StoragePolicyUpdate { + spu.mutation.ClearMaxSize() + return spu +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (spu *StoragePolicyUpdate) SetDirNameRule(s string) *StoragePolicyUpdate { + spu.mutation.SetDirNameRule(s) + return spu +} + +// SetNillableDirNameRule sets the "dir_name_rule" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableDirNameRule(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetDirNameRule(*s) + } + return spu +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (spu *StoragePolicyUpdate) ClearDirNameRule() *StoragePolicyUpdate { + spu.mutation.ClearDirNameRule() + return spu +} + +// SetFileNameRule sets the "file_name_rule" field. +func (spu *StoragePolicyUpdate) SetFileNameRule(s string) *StoragePolicyUpdate { + spu.mutation.SetFileNameRule(s) + return spu +} + +// SetNillableFileNameRule sets the "file_name_rule" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableFileNameRule(s *string) *StoragePolicyUpdate { + if s != nil { + spu.SetFileNameRule(*s) + } + return spu +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (spu *StoragePolicyUpdate) ClearFileNameRule() *StoragePolicyUpdate { + spu.mutation.ClearFileNameRule() + return spu +} + +// SetSettings sets the "settings" field. 
+func (spu *StoragePolicyUpdate) SetSettings(ts *types.PolicySetting) *StoragePolicyUpdate { + spu.mutation.SetSettings(ts) + return spu +} + +// ClearSettings clears the value of the "settings" field. +func (spu *StoragePolicyUpdate) ClearSettings() *StoragePolicyUpdate { + spu.mutation.ClearSettings() + return spu +} + +// SetNodeID sets the "node_id" field. +func (spu *StoragePolicyUpdate) SetNodeID(i int) *StoragePolicyUpdate { + spu.mutation.SetNodeID(i) + return spu +} + +// SetNillableNodeID sets the "node_id" field if the given value is not nil. +func (spu *StoragePolicyUpdate) SetNillableNodeID(i *int) *StoragePolicyUpdate { + if i != nil { + spu.SetNodeID(*i) + } + return spu +} + +// ClearNodeID clears the value of the "node_id" field. +func (spu *StoragePolicyUpdate) ClearNodeID() *StoragePolicyUpdate { + spu.mutation.ClearNodeID() + return spu +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (spu *StoragePolicyUpdate) AddUserIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.AddUserIDs(ids...) + return spu +} + +// AddUsers adds the "users" edges to the User entity. +func (spu *StoragePolicyUpdate) AddUsers(u ...*User) *StoragePolicyUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return spu.AddUserIDs(ids...) +} + +// AddGroupIDs adds the "groups" edge to the Group entity by IDs. +func (spu *StoragePolicyUpdate) AddGroupIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.AddGroupIDs(ids...) + return spu +} + +// AddGroups adds the "groups" edges to the Group entity. +func (spu *StoragePolicyUpdate) AddGroups(g ...*Group) *StoragePolicyUpdate { + ids := make([]int, len(g)) + for i := range g { + ids[i] = g[i].ID + } + return spu.AddGroupIDs(ids...) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (spu *StoragePolicyUpdate) AddFileIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.AddFileIDs(ids...) + return spu +} + +// AddFiles adds the "files" edges to the File entity. +func (spu *StoragePolicyUpdate) AddFiles(f ...*File) *StoragePolicyUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return spu.AddFileIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (spu *StoragePolicyUpdate) AddEntityIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.AddEntityIDs(ids...) + return spu +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (spu *StoragePolicyUpdate) AddEntities(e ...*Entity) *StoragePolicyUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return spu.AddEntityIDs(ids...) +} + +// SetNode sets the "node" edge to the Node entity. +func (spu *StoragePolicyUpdate) SetNode(n *Node) *StoragePolicyUpdate { + return spu.SetNodeID(n.ID) +} + +// Mutation returns the StoragePolicyMutation object of the builder. +func (spu *StoragePolicyUpdate) Mutation() *StoragePolicyMutation { + return spu.mutation +} + +// ClearUsers clears all "users" edges to the User entity. +func (spu *StoragePolicyUpdate) ClearUsers() *StoragePolicyUpdate { + spu.mutation.ClearUsers() + return spu +} + +// RemoveUserIDs removes the "users" edge to User entities by IDs. +func (spu *StoragePolicyUpdate) RemoveUserIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.RemoveUserIDs(ids...) + return spu +} + +// RemoveUsers removes "users" edges to User entities. 
+func (spu *StoragePolicyUpdate) RemoveUsers(u ...*User) *StoragePolicyUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return spu.RemoveUserIDs(ids...) +} + +// ClearGroups clears all "groups" edges to the Group entity. +func (spu *StoragePolicyUpdate) ClearGroups() *StoragePolicyUpdate { + spu.mutation.ClearGroups() + return spu +} + +// RemoveGroupIDs removes the "groups" edge to Group entities by IDs. +func (spu *StoragePolicyUpdate) RemoveGroupIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.RemoveGroupIDs(ids...) + return spu +} + +// RemoveGroups removes "groups" edges to Group entities. +func (spu *StoragePolicyUpdate) RemoveGroups(g ...*Group) *StoragePolicyUpdate { + ids := make([]int, len(g)) + for i := range g { + ids[i] = g[i].ID + } + return spu.RemoveGroupIDs(ids...) +} + +// ClearFiles clears all "files" edges to the File entity. +func (spu *StoragePolicyUpdate) ClearFiles() *StoragePolicyUpdate { + spu.mutation.ClearFiles() + return spu +} + +// RemoveFileIDs removes the "files" edge to File entities by IDs. +func (spu *StoragePolicyUpdate) RemoveFileIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.RemoveFileIDs(ids...) + return spu +} + +// RemoveFiles removes "files" edges to File entities. +func (spu *StoragePolicyUpdate) RemoveFiles(f ...*File) *StoragePolicyUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return spu.RemoveFileIDs(ids...) +} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (spu *StoragePolicyUpdate) ClearEntities() *StoragePolicyUpdate { + spu.mutation.ClearEntities() + return spu +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (spu *StoragePolicyUpdate) RemoveEntityIDs(ids ...int) *StoragePolicyUpdate { + spu.mutation.RemoveEntityIDs(ids...) + return spu +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (spu *StoragePolicyUpdate) RemoveEntities(e ...*Entity) *StoragePolicyUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return spu.RemoveEntityIDs(ids...) +} + +// ClearNode clears the "node" edge to the Node entity. +func (spu *StoragePolicyUpdate) ClearNode() *StoragePolicyUpdate { + spu.mutation.ClearNode() + return spu +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (spu *StoragePolicyUpdate) Save(ctx context.Context) (int, error) { + if err := spu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, spu.sqlSave, spu.mutation, spu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (spu *StoragePolicyUpdate) SaveX(ctx context.Context) int { + affected, err := spu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (spu *StoragePolicyUpdate) Exec(ctx context.Context) error { + _, err := spu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (spu *StoragePolicyUpdate) ExecX(ctx context.Context) { + if err := spu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (spu *StoragePolicyUpdate) defaults() error { + if _, ok := spu.mutation.UpdatedAt(); !ok { + if storagepolicy.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized storagepolicy.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := storagepolicy.UpdateDefaultUpdatedAt() + spu.mutation.SetUpdatedAt(v) + } + return nil +} + +func (spu *StoragePolicyUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(storagepolicy.Table, storagepolicy.Columns, sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt)) + if ps := spu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := spu.mutation.UpdatedAt(); ok { + _spec.SetField(storagepolicy.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := spu.mutation.DeletedAt(); ok { + _spec.SetField(storagepolicy.FieldDeletedAt, field.TypeTime, value) + } + if spu.mutation.DeletedAtCleared() { + _spec.ClearField(storagepolicy.FieldDeletedAt, field.TypeTime) + } + if value, ok := spu.mutation.Name(); ok { + _spec.SetField(storagepolicy.FieldName, field.TypeString, value) + } + if value, ok := spu.mutation.GetType(); ok { + _spec.SetField(storagepolicy.FieldType, field.TypeString, value) + } + if value, ok := spu.mutation.Server(); ok { + _spec.SetField(storagepolicy.FieldServer, field.TypeString, value) + } + if spu.mutation.ServerCleared() { + _spec.ClearField(storagepolicy.FieldServer, field.TypeString) + } + if value, ok := spu.mutation.BucketName(); ok { + _spec.SetField(storagepolicy.FieldBucketName, field.TypeString, value) + } + if spu.mutation.BucketNameCleared() { + _spec.ClearField(storagepolicy.FieldBucketName, field.TypeString) + } + if value, ok := spu.mutation.IsPrivate(); ok { + _spec.SetField(storagepolicy.FieldIsPrivate, field.TypeBool, value) + } + if spu.mutation.IsPrivateCleared() { + _spec.ClearField(storagepolicy.FieldIsPrivate, field.TypeBool) + } + if value, ok := spu.mutation.AccessKey(); ok { + _spec.SetField(storagepolicy.FieldAccessKey, field.TypeString, value) + } + if spu.mutation.AccessKeyCleared() { + _spec.ClearField(storagepolicy.FieldAccessKey, field.TypeString) + } + if value, ok := spu.mutation.SecretKey(); ok { + _spec.SetField(storagepolicy.FieldSecretKey, field.TypeString, value) + } + if spu.mutation.SecretKeyCleared() { + _spec.ClearField(storagepolicy.FieldSecretKey, field.TypeString) + } + if value, ok := spu.mutation.MaxSize(); ok { + _spec.SetField(storagepolicy.FieldMaxSize, field.TypeInt64, value) + } + if value, ok := spu.mutation.AddedMaxSize(); ok { + _spec.AddField(storagepolicy.FieldMaxSize, field.TypeInt64, value) + } + if spu.mutation.MaxSizeCleared() { + _spec.ClearField(storagepolicy.FieldMaxSize, field.TypeInt64) + } + if value, ok := spu.mutation.DirNameRule(); ok { + _spec.SetField(storagepolicy.FieldDirNameRule, field.TypeString, value) + } + if spu.mutation.DirNameRuleCleared() { + _spec.ClearField(storagepolicy.FieldDirNameRule, field.TypeString) + } + if value, ok := spu.mutation.FileNameRule(); ok { + _spec.SetField(storagepolicy.FieldFileNameRule, field.TypeString, value) + } + if spu.mutation.FileNameRuleCleared() { + _spec.ClearField(storagepolicy.FieldFileNameRule, field.TypeString) + } + if value, ok := spu.mutation.Settings(); ok { + _spec.SetField(storagepolicy.FieldSettings, field.TypeJSON, value) + } + if spu.mutation.SettingsCleared() { + _spec.ClearField(storagepolicy.FieldSettings, field.TypeJSON) + } 
+ if spu.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.RemovedUsersIDs(); len(nodes) > 0 && !spu.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spu.mutation.GroupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.RemovedGroupsIDs(); len(nodes) > 0 && !spu.mutation.GroupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.GroupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spu.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.RemovedFilesIDs(); len(nodes) > 0 && !spu.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if 
nodes := spu.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.RemovedEntitiesIDs(); len(nodes) > 0 && !spu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spu.mutation.NodeCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: storagepolicy.NodeTable, + Columns: []string{storagepolicy.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spu.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: storagepolicy.NodeTable, + Columns: []string{storagepolicy.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, spu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{storagepolicy.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + spu.mutation.done = true + return n, nil +} + +// StoragePolicyUpdateOne is the builder for updating a single StoragePolicy entity. +type StoragePolicyUpdateOne struct { + config + fields []string + hooks []Hook + mutation *StoragePolicyMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (spuo *StoragePolicyUpdateOne) SetUpdatedAt(t time.Time) *StoragePolicyUpdateOne { + spuo.mutation.SetUpdatedAt(t) + return spuo +} + +// SetDeletedAt sets the "deleted_at" field. 
+func (spuo *StoragePolicyUpdateOne) SetDeletedAt(t time.Time) *StoragePolicyUpdateOne { + spuo.mutation.SetDeletedAt(t) + return spuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableDeletedAt(t *time.Time) *StoragePolicyUpdateOne { + if t != nil { + spuo.SetDeletedAt(*t) + } + return spuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (spuo *StoragePolicyUpdateOne) ClearDeletedAt() *StoragePolicyUpdateOne { + spuo.mutation.ClearDeletedAt() + return spuo +} + +// SetName sets the "name" field. +func (spuo *StoragePolicyUpdateOne) SetName(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetName(s) + return spuo +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableName(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetName(*s) + } + return spuo +} + +// SetType sets the "type" field. +func (spuo *StoragePolicyUpdateOne) SetType(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetType(s) + return spuo +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableType(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetType(*s) + } + return spuo +} + +// SetServer sets the "server" field. +func (spuo *StoragePolicyUpdateOne) SetServer(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetServer(s) + return spuo +} + +// SetNillableServer sets the "server" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableServer(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetServer(*s) + } + return spuo +} + +// ClearServer clears the value of the "server" field. +func (spuo *StoragePolicyUpdateOne) ClearServer() *StoragePolicyUpdateOne { + spuo.mutation.ClearServer() + return spuo +} + +// SetBucketName sets the "bucket_name" field. +func (spuo *StoragePolicyUpdateOne) SetBucketName(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetBucketName(s) + return spuo +} + +// SetNillableBucketName sets the "bucket_name" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableBucketName(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetBucketName(*s) + } + return spuo +} + +// ClearBucketName clears the value of the "bucket_name" field. +func (spuo *StoragePolicyUpdateOne) ClearBucketName() *StoragePolicyUpdateOne { + spuo.mutation.ClearBucketName() + return spuo +} + +// SetIsPrivate sets the "is_private" field. +func (spuo *StoragePolicyUpdateOne) SetIsPrivate(b bool) *StoragePolicyUpdateOne { + spuo.mutation.SetIsPrivate(b) + return spuo +} + +// SetNillableIsPrivate sets the "is_private" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableIsPrivate(b *bool) *StoragePolicyUpdateOne { + if b != nil { + spuo.SetIsPrivate(*b) + } + return spuo +} + +// ClearIsPrivate clears the value of the "is_private" field. +func (spuo *StoragePolicyUpdateOne) ClearIsPrivate() *StoragePolicyUpdateOne { + spuo.mutation.ClearIsPrivate() + return spuo +} + +// SetAccessKey sets the "access_key" field. +func (spuo *StoragePolicyUpdateOne) SetAccessKey(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetAccessKey(s) + return spuo +} + +// SetNillableAccessKey sets the "access_key" field if the given value is not nil. 
+func (spuo *StoragePolicyUpdateOne) SetNillableAccessKey(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetAccessKey(*s) + } + return spuo +} + +// ClearAccessKey clears the value of the "access_key" field. +func (spuo *StoragePolicyUpdateOne) ClearAccessKey() *StoragePolicyUpdateOne { + spuo.mutation.ClearAccessKey() + return spuo +} + +// SetSecretKey sets the "secret_key" field. +func (spuo *StoragePolicyUpdateOne) SetSecretKey(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetSecretKey(s) + return spuo +} + +// SetNillableSecretKey sets the "secret_key" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableSecretKey(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetSecretKey(*s) + } + return spuo +} + +// ClearSecretKey clears the value of the "secret_key" field. +func (spuo *StoragePolicyUpdateOne) ClearSecretKey() *StoragePolicyUpdateOne { + spuo.mutation.ClearSecretKey() + return spuo +} + +// SetMaxSize sets the "max_size" field. +func (spuo *StoragePolicyUpdateOne) SetMaxSize(i int64) *StoragePolicyUpdateOne { + spuo.mutation.ResetMaxSize() + spuo.mutation.SetMaxSize(i) + return spuo +} + +// SetNillableMaxSize sets the "max_size" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableMaxSize(i *int64) *StoragePolicyUpdateOne { + if i != nil { + spuo.SetMaxSize(*i) + } + return spuo +} + +// AddMaxSize adds i to the "max_size" field. +func (spuo *StoragePolicyUpdateOne) AddMaxSize(i int64) *StoragePolicyUpdateOne { + spuo.mutation.AddMaxSize(i) + return spuo +} + +// ClearMaxSize clears the value of the "max_size" field. +func (spuo *StoragePolicyUpdateOne) ClearMaxSize() *StoragePolicyUpdateOne { + spuo.mutation.ClearMaxSize() + return spuo +} + +// SetDirNameRule sets the "dir_name_rule" field. +func (spuo *StoragePolicyUpdateOne) SetDirNameRule(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetDirNameRule(s) + return spuo +} + +// SetNillableDirNameRule sets the "dir_name_rule" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableDirNameRule(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetDirNameRule(*s) + } + return spuo +} + +// ClearDirNameRule clears the value of the "dir_name_rule" field. +func (spuo *StoragePolicyUpdateOne) ClearDirNameRule() *StoragePolicyUpdateOne { + spuo.mutation.ClearDirNameRule() + return spuo +} + +// SetFileNameRule sets the "file_name_rule" field. +func (spuo *StoragePolicyUpdateOne) SetFileNameRule(s string) *StoragePolicyUpdateOne { + spuo.mutation.SetFileNameRule(s) + return spuo +} + +// SetNillableFileNameRule sets the "file_name_rule" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableFileNameRule(s *string) *StoragePolicyUpdateOne { + if s != nil { + spuo.SetFileNameRule(*s) + } + return spuo +} + +// ClearFileNameRule clears the value of the "file_name_rule" field. +func (spuo *StoragePolicyUpdateOne) ClearFileNameRule() *StoragePolicyUpdateOne { + spuo.mutation.ClearFileNameRule() + return spuo +} + +// SetSettings sets the "settings" field. +func (spuo *StoragePolicyUpdateOne) SetSettings(ts *types.PolicySetting) *StoragePolicyUpdateOne { + spuo.mutation.SetSettings(ts) + return spuo +} + +// ClearSettings clears the value of the "settings" field. +func (spuo *StoragePolicyUpdateOne) ClearSettings() *StoragePolicyUpdateOne { + spuo.mutation.ClearSettings() + return spuo +} + +// SetNodeID sets the "node_id" field. 
+func (spuo *StoragePolicyUpdateOne) SetNodeID(i int) *StoragePolicyUpdateOne { + spuo.mutation.SetNodeID(i) + return spuo +} + +// SetNillableNodeID sets the "node_id" field if the given value is not nil. +func (spuo *StoragePolicyUpdateOne) SetNillableNodeID(i *int) *StoragePolicyUpdateOne { + if i != nil { + spuo.SetNodeID(*i) + } + return spuo +} + +// ClearNodeID clears the value of the "node_id" field. +func (spuo *StoragePolicyUpdateOne) ClearNodeID() *StoragePolicyUpdateOne { + spuo.mutation.ClearNodeID() + return spuo +} + +// AddUserIDs adds the "users" edge to the User entity by IDs. +func (spuo *StoragePolicyUpdateOne) AddUserIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.AddUserIDs(ids...) + return spuo +} + +// AddUsers adds the "users" edges to the User entity. +func (spuo *StoragePolicyUpdateOne) AddUsers(u ...*User) *StoragePolicyUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return spuo.AddUserIDs(ids...) +} + +// AddGroupIDs adds the "groups" edge to the Group entity by IDs. +func (spuo *StoragePolicyUpdateOne) AddGroupIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.AddGroupIDs(ids...) + return spuo +} + +// AddGroups adds the "groups" edges to the Group entity. +func (spuo *StoragePolicyUpdateOne) AddGroups(g ...*Group) *StoragePolicyUpdateOne { + ids := make([]int, len(g)) + for i := range g { + ids[i] = g[i].ID + } + return spuo.AddGroupIDs(ids...) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (spuo *StoragePolicyUpdateOne) AddFileIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.AddFileIDs(ids...) + return spuo +} + +// AddFiles adds the "files" edges to the File entity. +func (spuo *StoragePolicyUpdateOne) AddFiles(f ...*File) *StoragePolicyUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return spuo.AddFileIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (spuo *StoragePolicyUpdateOne) AddEntityIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.AddEntityIDs(ids...) + return spuo +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (spuo *StoragePolicyUpdateOne) AddEntities(e ...*Entity) *StoragePolicyUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return spuo.AddEntityIDs(ids...) +} + +// SetNode sets the "node" edge to the Node entity. +func (spuo *StoragePolicyUpdateOne) SetNode(n *Node) *StoragePolicyUpdateOne { + return spuo.SetNodeID(n.ID) +} + +// Mutation returns the StoragePolicyMutation object of the builder. +func (spuo *StoragePolicyUpdateOne) Mutation() *StoragePolicyMutation { + return spuo.mutation +} + +// ClearUsers clears all "users" edges to the User entity. +func (spuo *StoragePolicyUpdateOne) ClearUsers() *StoragePolicyUpdateOne { + spuo.mutation.ClearUsers() + return spuo +} + +// RemoveUserIDs removes the "users" edge to User entities by IDs. +func (spuo *StoragePolicyUpdateOne) RemoveUserIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.RemoveUserIDs(ids...) + return spuo +} + +// RemoveUsers removes "users" edges to User entities. +func (spuo *StoragePolicyUpdateOne) RemoveUsers(u ...*User) *StoragePolicyUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return spuo.RemoveUserIDs(ids...) +} + +// ClearGroups clears all "groups" edges to the Group entity. 
+func (spuo *StoragePolicyUpdateOne) ClearGroups() *StoragePolicyUpdateOne { + spuo.mutation.ClearGroups() + return spuo +} + +// RemoveGroupIDs removes the "groups" edge to Group entities by IDs. +func (spuo *StoragePolicyUpdateOne) RemoveGroupIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.RemoveGroupIDs(ids...) + return spuo +} + +// RemoveGroups removes "groups" edges to Group entities. +func (spuo *StoragePolicyUpdateOne) RemoveGroups(g ...*Group) *StoragePolicyUpdateOne { + ids := make([]int, len(g)) + for i := range g { + ids[i] = g[i].ID + } + return spuo.RemoveGroupIDs(ids...) +} + +// ClearFiles clears all "files" edges to the File entity. +func (spuo *StoragePolicyUpdateOne) ClearFiles() *StoragePolicyUpdateOne { + spuo.mutation.ClearFiles() + return spuo +} + +// RemoveFileIDs removes the "files" edge to File entities by IDs. +func (spuo *StoragePolicyUpdateOne) RemoveFileIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.RemoveFileIDs(ids...) + return spuo +} + +// RemoveFiles removes "files" edges to File entities. +func (spuo *StoragePolicyUpdateOne) RemoveFiles(f ...*File) *StoragePolicyUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return spuo.RemoveFileIDs(ids...) +} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (spuo *StoragePolicyUpdateOne) ClearEntities() *StoragePolicyUpdateOne { + spuo.mutation.ClearEntities() + return spuo +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (spuo *StoragePolicyUpdateOne) RemoveEntityIDs(ids ...int) *StoragePolicyUpdateOne { + spuo.mutation.RemoveEntityIDs(ids...) + return spuo +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (spuo *StoragePolicyUpdateOne) RemoveEntities(e ...*Entity) *StoragePolicyUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return spuo.RemoveEntityIDs(ids...) +} + +// ClearNode clears the "node" edge to the Node entity. +func (spuo *StoragePolicyUpdateOne) ClearNode() *StoragePolicyUpdateOne { + spuo.mutation.ClearNode() + return spuo +} + +// Where appends a list predicates to the StoragePolicyUpdate builder. +func (spuo *StoragePolicyUpdateOne) Where(ps ...predicate.StoragePolicy) *StoragePolicyUpdateOne { + spuo.mutation.Where(ps...) + return spuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (spuo *StoragePolicyUpdateOne) Select(field string, fields ...string) *StoragePolicyUpdateOne { + spuo.fields = append([]string{field}, fields...) + return spuo +} + +// Save executes the query and returns the updated StoragePolicy entity. +func (spuo *StoragePolicyUpdateOne) Save(ctx context.Context) (*StoragePolicy, error) { + if err := spuo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, spuo.sqlSave, spuo.mutation, spuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (spuo *StoragePolicyUpdateOne) SaveX(ctx context.Context) *StoragePolicy { + node, err := spuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (spuo *StoragePolicyUpdateOne) Exec(ctx context.Context) error { + _, err := spuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (spuo *StoragePolicyUpdateOne) ExecX(ctx context.Context) { + if err := spuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (spuo *StoragePolicyUpdateOne) defaults() error { + if _, ok := spuo.mutation.UpdatedAt(); !ok { + if storagepolicy.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized storagepolicy.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := storagepolicy.UpdateDefaultUpdatedAt() + spuo.mutation.SetUpdatedAt(v) + } + return nil +} + +func (spuo *StoragePolicyUpdateOne) sqlSave(ctx context.Context) (_node *StoragePolicy, err error) { + _spec := sqlgraph.NewUpdateSpec(storagepolicy.Table, storagepolicy.Columns, sqlgraph.NewFieldSpec(storagepolicy.FieldID, field.TypeInt)) + id, ok := spuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "StoragePolicy.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := spuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, storagepolicy.FieldID) + for _, f := range fields { + if !storagepolicy.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != storagepolicy.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := spuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := spuo.mutation.UpdatedAt(); ok { + _spec.SetField(storagepolicy.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := spuo.mutation.DeletedAt(); ok { + _spec.SetField(storagepolicy.FieldDeletedAt, field.TypeTime, value) + } + if spuo.mutation.DeletedAtCleared() { + _spec.ClearField(storagepolicy.FieldDeletedAt, field.TypeTime) + } + if value, ok := spuo.mutation.Name(); ok { + _spec.SetField(storagepolicy.FieldName, field.TypeString, value) + } + if value, ok := spuo.mutation.GetType(); ok { + _spec.SetField(storagepolicy.FieldType, field.TypeString, value) + } + if value, ok := spuo.mutation.Server(); ok { + _spec.SetField(storagepolicy.FieldServer, field.TypeString, value) + } + if spuo.mutation.ServerCleared() { + _spec.ClearField(storagepolicy.FieldServer, field.TypeString) + } + if value, ok := spuo.mutation.BucketName(); ok { + _spec.SetField(storagepolicy.FieldBucketName, field.TypeString, value) + } + if spuo.mutation.BucketNameCleared() { + _spec.ClearField(storagepolicy.FieldBucketName, field.TypeString) + } + if value, ok := spuo.mutation.IsPrivate(); ok { + _spec.SetField(storagepolicy.FieldIsPrivate, field.TypeBool, value) + } + if spuo.mutation.IsPrivateCleared() { + _spec.ClearField(storagepolicy.FieldIsPrivate, field.TypeBool) + } + if value, ok := spuo.mutation.AccessKey(); ok { + _spec.SetField(storagepolicy.FieldAccessKey, field.TypeString, value) + } + if spuo.mutation.AccessKeyCleared() { + _spec.ClearField(storagepolicy.FieldAccessKey, field.TypeString) + } + if value, ok := spuo.mutation.SecretKey(); ok { + _spec.SetField(storagepolicy.FieldSecretKey, field.TypeString, value) + } + if spuo.mutation.SecretKeyCleared() { + _spec.ClearField(storagepolicy.FieldSecretKey, field.TypeString) + } + if value, ok := spuo.mutation.MaxSize(); ok { + _spec.SetField(storagepolicy.FieldMaxSize, field.TypeInt64, value) + } + if value, ok := spuo.mutation.AddedMaxSize(); ok { + 
_spec.AddField(storagepolicy.FieldMaxSize, field.TypeInt64, value) + } + if spuo.mutation.MaxSizeCleared() { + _spec.ClearField(storagepolicy.FieldMaxSize, field.TypeInt64) + } + if value, ok := spuo.mutation.DirNameRule(); ok { + _spec.SetField(storagepolicy.FieldDirNameRule, field.TypeString, value) + } + if spuo.mutation.DirNameRuleCleared() { + _spec.ClearField(storagepolicy.FieldDirNameRule, field.TypeString) + } + if value, ok := spuo.mutation.FileNameRule(); ok { + _spec.SetField(storagepolicy.FieldFileNameRule, field.TypeString, value) + } + if spuo.mutation.FileNameRuleCleared() { + _spec.ClearField(storagepolicy.FieldFileNameRule, field.TypeString) + } + if value, ok := spuo.mutation.Settings(); ok { + _spec.SetField(storagepolicy.FieldSettings, field.TypeJSON, value) + } + if spuo.mutation.SettingsCleared() { + _spec.ClearField(storagepolicy.FieldSettings, field.TypeJSON) + } + if spuo.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.RemovedUsersIDs(); len(nodes) > 0 && !spuo.mutation.UsersCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.UsersIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.UsersTable, + Columns: []string{storagepolicy.UsersColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spuo.mutation.GroupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.RemovedGroupsIDs(); len(nodes) > 0 && !spuo.mutation.GroupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.GroupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.GroupsTable, + Columns: []string{storagepolicy.GroupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = 
append(_spec.Edges.Add, edge) + } + if spuo.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.RemovedFilesIDs(); len(nodes) > 0 && !spuo.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.FilesTable, + Columns: []string{storagepolicy.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.RemovedEntitiesIDs(); len(nodes) > 0 && !spuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: storagepolicy.EntitiesTable, + Columns: []string{storagepolicy.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if spuo.mutation.NodeCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: storagepolicy.NodeTable, + Columns: []string{storagepolicy.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := spuo.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: storagepolicy.NodeTable, + Columns: []string{storagepolicy.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, 
edge) + } + _node = &StoragePolicy{config: spuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, spuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{storagepolicy.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + spuo.mutation.done = true + return _node, nil +} diff --git a/ent/task.go b/ent/task.go new file mode 100644 index 00000000..67a14b99 --- /dev/null +++ b/ent/task.go @@ -0,0 +1,243 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// Task is the model entity for the Task schema. +type Task struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Type holds the value of the "type" field. + Type string `json:"type,omitempty"` + // Status holds the value of the "status" field. + Status task.Status `json:"status,omitempty"` + // PublicState holds the value of the "public_state" field. + PublicState *types.TaskPublicState `json:"public_state,omitempty"` + // PrivateState holds the value of the "private_state" field. + PrivateState string `json:"private_state,omitempty"` + // CorrelationID holds the value of the "correlation_id" field. + CorrelationID uuid.UUID `json:"correlation_id,omitempty"` + // UserTasks holds the value of the "user_tasks" field. + UserTasks int `json:"user_tasks,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the TaskQuery when eager-loading is set. + Edges TaskEdges `json:"edges"` + selectValues sql.SelectValues +} + +// TaskEdges holds the relations/edges for other nodes in the graph. +type TaskEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e TaskEdges) UserOrErr() (*User, error) { + if e.loadedTypes[0] { + if e.User == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: user.Label} + } + return e.User, nil + } + return nil, &NotLoadedError{edge: "user"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*Task) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case task.FieldPublicState: + values[i] = new([]byte) + case task.FieldID, task.FieldUserTasks: + values[i] = new(sql.NullInt64) + case task.FieldType, task.FieldStatus, task.FieldPrivateState: + values[i] = new(sql.NullString) + case task.FieldCreatedAt, task.FieldUpdatedAt, task.FieldDeletedAt: + values[i] = new(sql.NullTime) + case task.FieldCorrelationID: + values[i] = new(uuid.UUID) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Task fields. +func (t *Task) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case task.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + t.ID = int(value.Int64) + case task.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + t.CreatedAt = value.Time + } + case task.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + t.UpdatedAt = value.Time + } + case task.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + t.DeletedAt = new(time.Time) + *t.DeletedAt = value.Time + } + case task.FieldType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + t.Type = value.String + } + case task.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + t.Status = task.Status(value.String) + } + case task.FieldPublicState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field public_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &t.PublicState); err != nil { + return fmt.Errorf("unmarshal field public_state: %w", err) + } + } + case task.FieldPrivateState: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field private_state", values[i]) + } else if value.Valid { + t.PrivateState = value.String + } + case task.FieldCorrelationID: + if value, ok := values[i].(*uuid.UUID); !ok { + return fmt.Errorf("unexpected type %T for field correlation_id", values[i]) + } else if value != nil { + t.CorrelationID = *value + } + case task.FieldUserTasks: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_tasks", values[i]) + } else if value.Valid { + t.UserTasks = int(value.Int64) + } + default: + t.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Task. +// This includes values selected through modifiers, order, etc. 
+func (t *Task) Value(name string) (ent.Value, error) { + return t.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the Task entity. +func (t *Task) QueryUser() *UserQuery { + return NewTaskClient(t.config).QueryUser(t) +} + +// Update returns a builder for updating this Task. +// Note that you need to call Task.Unwrap() before calling this method if this Task +// was returned from a transaction, and the transaction was committed or rolled back. +func (t *Task) Update() *TaskUpdateOne { + return NewTaskClient(t.config).UpdateOne(t) +} + +// Unwrap unwraps the Task entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (t *Task) Unwrap() *Task { + _tx, ok := t.config.driver.(*txDriver) + if !ok { + panic("ent: Task is not a transactional entity") + } + t.config.driver = _tx.drv + return t +} + +// String implements the fmt.Stringer. +func (t *Task) String() string { + var builder strings.Builder + builder.WriteString("Task(") + builder.WriteString(fmt.Sprintf("id=%v, ", t.ID)) + builder.WriteString("created_at=") + builder.WriteString(t.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(t.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := t.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("type=") + builder.WriteString(t.Type) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", t.Status)) + builder.WriteString(", ") + builder.WriteString("public_state=") + builder.WriteString(fmt.Sprintf("%v", t.PublicState)) + builder.WriteString(", ") + builder.WriteString("private_state=") + builder.WriteString(t.PrivateState) + builder.WriteString(", ") + builder.WriteString("correlation_id=") + builder.WriteString(fmt.Sprintf("%v", t.CorrelationID)) + builder.WriteString(", ") + builder.WriteString("user_tasks=") + builder.WriteString(fmt.Sprintf("%v", t.UserTasks)) + builder.WriteByte(')') + return builder.String() +} + +// SetUser manually set the edge as loaded state. +func (e *Task) SetUser(v *User) { + e.Edges.User = v + e.Edges.loadedTypes[0] = true +} + +// Tasks is a parsable slice of Task. +type Tasks []*Task diff --git a/ent/task/task.go b/ent/task/task.go new file mode 100644 index 00000000..7f96f1c9 --- /dev/null +++ b/ent/task/task.go @@ -0,0 +1,180 @@ +// Code generated by ent, DO NOT EDIT. + +package task + +import ( + "fmt" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the task type in the database. + Label = "task" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldStatus holds the string denoting the status field in the database. 
+ FieldStatus = "status" + // FieldPublicState holds the string denoting the public_state field in the database. + FieldPublicState = "public_state" + // FieldPrivateState holds the string denoting the private_state field in the database. + FieldPrivateState = "private_state" + // FieldCorrelationID holds the string denoting the correlation_id field in the database. + FieldCorrelationID = "correlation_id" + // FieldUserTasks holds the string denoting the user_tasks field in the database. + FieldUserTasks = "user_tasks" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // Table holds the table name of the task in the database. + Table = "tasks" + // UserTable is the table that holds the user relation/edge. + UserTable = "tasks" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_tasks" +) + +// Columns holds all SQL columns for task fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldType, + FieldStatus, + FieldPublicState, + FieldPrivateState, + FieldCorrelationID, + FieldUserTasks, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// Status defines the type for the "status" enum field. +type Status string + +// StatusQueued is the default value of the Status enum. +const DefaultStatus = StatusQueued + +// Status values. +const ( + StatusQueued Status = "queued" + StatusProcessing Status = "processing" + StatusSuspending Status = "suspending" + StatusError Status = "error" + StatusCanceled Status = "canceled" + StatusCompleted Status = "completed" +) + +func (s Status) String() string { + return string(s) +} + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s Status) error { + switch s { + case StatusQueued, StatusProcessing, StatusSuspending, StatusError, StatusCanceled, StatusCompleted: + return nil + default: + return fmt.Errorf("task: invalid enum value for status field: %q", s) + } +} + +// OrderOption defines the ordering options for the Task queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. 
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByPrivateState orders the results by the private_state field. +func ByPrivateState(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrivateState, opts...).ToFunc() +} + +// ByCorrelationID orders the results by the correlation_id field. +func ByCorrelationID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCorrelationID, opts...).ToFunc() +} + +// ByUserTasks orders the results by the user_tasks field. +func ByUserTasks(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserTasks, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} diff --git a/ent/task/where.go b/ent/task/where.go new file mode 100644 index 00000000..59a7ced6 --- /dev/null +++ b/ent/task/where.go @@ -0,0 +1,500 @@ +// Code generated by ent, DO NOT EDIT. + +package task + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/gofrs/uuid" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Task { + return predicate.Task(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Task { + return predicate.Task(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Task { + return predicate.Task(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. 
+func IDLTE(id int) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Type applies equality check predicate on the "type" field. It's identical to TypeEQ. +func Type(v string) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldType, v)) +} + +// PrivateState applies equality check predicate on the "private_state" field. It's identical to PrivateStateEQ. +func PrivateState(v string) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldPrivateState, v)) +} + +// CorrelationID applies equality check predicate on the "correlation_id" field. It's identical to CorrelationIDEQ. +func CorrelationID(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldCorrelationID, v)) +} + +// UserTasks applies equality check predicate on the "user_tasks" field. It's identical to UserTasksEQ. +func UserTasks(v int) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldUserTasks, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. 
+func UpdatedAtIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.Task { + return predicate.Task(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.Task { + return predicate.Task(sql.FieldNotNull(FieldDeletedAt)) +} + +// TypeEQ applies the EQ predicate on the "type" field. +func TypeEQ(v string) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v string) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...string) predicate.Task { + return predicate.Task(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...string) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldType, vs...)) +} + +// TypeGT applies the GT predicate on the "type" field. 
+func TypeGT(v string) predicate.Task { + return predicate.Task(sql.FieldGT(FieldType, v)) +} + +// TypeGTE applies the GTE predicate on the "type" field. +func TypeGTE(v string) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldType, v)) +} + +// TypeLT applies the LT predicate on the "type" field. +func TypeLT(v string) predicate.Task { + return predicate.Task(sql.FieldLT(FieldType, v)) +} + +// TypeLTE applies the LTE predicate on the "type" field. +func TypeLTE(v string) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldType, v)) +} + +// TypeContains applies the Contains predicate on the "type" field. +func TypeContains(v string) predicate.Task { + return predicate.Task(sql.FieldContains(FieldType, v)) +} + +// TypeHasPrefix applies the HasPrefix predicate on the "type" field. +func TypeHasPrefix(v string) predicate.Task { + return predicate.Task(sql.FieldHasPrefix(FieldType, v)) +} + +// TypeHasSuffix applies the HasSuffix predicate on the "type" field. +func TypeHasSuffix(v string) predicate.Task { + return predicate.Task(sql.FieldHasSuffix(FieldType, v)) +} + +// TypeEqualFold applies the EqualFold predicate on the "type" field. +func TypeEqualFold(v string) predicate.Task { + return predicate.Task(sql.FieldEqualFold(FieldType, v)) +} + +// TypeContainsFold applies the ContainsFold predicate on the "type" field. +func TypeContainsFold(v string) predicate.Task { + return predicate.Task(sql.FieldContainsFold(FieldType, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v Status) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v Status) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...Status) predicate.Task { + return predicate.Task(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...Status) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldStatus, vs...)) +} + +// PrivateStateEQ applies the EQ predicate on the "private_state" field. +func PrivateStateEQ(v string) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldPrivateState, v)) +} + +// PrivateStateNEQ applies the NEQ predicate on the "private_state" field. +func PrivateStateNEQ(v string) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldPrivateState, v)) +} + +// PrivateStateIn applies the In predicate on the "private_state" field. +func PrivateStateIn(vs ...string) predicate.Task { + return predicate.Task(sql.FieldIn(FieldPrivateState, vs...)) +} + +// PrivateStateNotIn applies the NotIn predicate on the "private_state" field. +func PrivateStateNotIn(vs ...string) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldPrivateState, vs...)) +} + +// PrivateStateGT applies the GT predicate on the "private_state" field. +func PrivateStateGT(v string) predicate.Task { + return predicate.Task(sql.FieldGT(FieldPrivateState, v)) +} + +// PrivateStateGTE applies the GTE predicate on the "private_state" field. +func PrivateStateGTE(v string) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldPrivateState, v)) +} + +// PrivateStateLT applies the LT predicate on the "private_state" field. 
+func PrivateStateLT(v string) predicate.Task { + return predicate.Task(sql.FieldLT(FieldPrivateState, v)) +} + +// PrivateStateLTE applies the LTE predicate on the "private_state" field. +func PrivateStateLTE(v string) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldPrivateState, v)) +} + +// PrivateStateContains applies the Contains predicate on the "private_state" field. +func PrivateStateContains(v string) predicate.Task { + return predicate.Task(sql.FieldContains(FieldPrivateState, v)) +} + +// PrivateStateHasPrefix applies the HasPrefix predicate on the "private_state" field. +func PrivateStateHasPrefix(v string) predicate.Task { + return predicate.Task(sql.FieldHasPrefix(FieldPrivateState, v)) +} + +// PrivateStateHasSuffix applies the HasSuffix predicate on the "private_state" field. +func PrivateStateHasSuffix(v string) predicate.Task { + return predicate.Task(sql.FieldHasSuffix(FieldPrivateState, v)) +} + +// PrivateStateIsNil applies the IsNil predicate on the "private_state" field. +func PrivateStateIsNil() predicate.Task { + return predicate.Task(sql.FieldIsNull(FieldPrivateState)) +} + +// PrivateStateNotNil applies the NotNil predicate on the "private_state" field. +func PrivateStateNotNil() predicate.Task { + return predicate.Task(sql.FieldNotNull(FieldPrivateState)) +} + +// PrivateStateEqualFold applies the EqualFold predicate on the "private_state" field. +func PrivateStateEqualFold(v string) predicate.Task { + return predicate.Task(sql.FieldEqualFold(FieldPrivateState, v)) +} + +// PrivateStateContainsFold applies the ContainsFold predicate on the "private_state" field. +func PrivateStateContainsFold(v string) predicate.Task { + return predicate.Task(sql.FieldContainsFold(FieldPrivateState, v)) +} + +// CorrelationIDEQ applies the EQ predicate on the "correlation_id" field. +func CorrelationIDEQ(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldCorrelationID, v)) +} + +// CorrelationIDNEQ applies the NEQ predicate on the "correlation_id" field. +func CorrelationIDNEQ(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldCorrelationID, v)) +} + +// CorrelationIDIn applies the In predicate on the "correlation_id" field. +func CorrelationIDIn(vs ...uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldIn(FieldCorrelationID, vs...)) +} + +// CorrelationIDNotIn applies the NotIn predicate on the "correlation_id" field. +func CorrelationIDNotIn(vs ...uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldCorrelationID, vs...)) +} + +// CorrelationIDGT applies the GT predicate on the "correlation_id" field. +func CorrelationIDGT(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldGT(FieldCorrelationID, v)) +} + +// CorrelationIDGTE applies the GTE predicate on the "correlation_id" field. +func CorrelationIDGTE(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldGTE(FieldCorrelationID, v)) +} + +// CorrelationIDLT applies the LT predicate on the "correlation_id" field. +func CorrelationIDLT(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldLT(FieldCorrelationID, v)) +} + +// CorrelationIDLTE applies the LTE predicate on the "correlation_id" field. +func CorrelationIDLTE(v uuid.UUID) predicate.Task { + return predicate.Task(sql.FieldLTE(FieldCorrelationID, v)) +} + +// CorrelationIDIsNil applies the IsNil predicate on the "correlation_id" field. 
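+//
+// Illustrative sketch, not part of the generated code (assumes an initialized
+// *ent.Client named "client" and a context "ctx"): count tasks that were
+// never assigned a correlation ID.
+//
+//	n, err := client.Task.Query().
+//		Where(task.CorrelationIDIsNil()).
+//		Count(ctx)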
+func CorrelationIDIsNil() predicate.Task { + return predicate.Task(sql.FieldIsNull(FieldCorrelationID)) +} + +// CorrelationIDNotNil applies the NotNil predicate on the "correlation_id" field. +func CorrelationIDNotNil() predicate.Task { + return predicate.Task(sql.FieldNotNull(FieldCorrelationID)) +} + +// UserTasksEQ applies the EQ predicate on the "user_tasks" field. +func UserTasksEQ(v int) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldUserTasks, v)) +} + +// UserTasksNEQ applies the NEQ predicate on the "user_tasks" field. +func UserTasksNEQ(v int) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldUserTasks, v)) +} + +// UserTasksIn applies the In predicate on the "user_tasks" field. +func UserTasksIn(vs ...int) predicate.Task { + return predicate.Task(sql.FieldIn(FieldUserTasks, vs...)) +} + +// UserTasksNotIn applies the NotIn predicate on the "user_tasks" field. +func UserTasksNotIn(vs ...int) predicate.Task { + return predicate.Task(sql.FieldNotIn(FieldUserTasks, vs...)) +} + +// UserTasksIsNil applies the IsNil predicate on the "user_tasks" field. +func UserTasksIsNil() predicate.Task { + return predicate.Task(sql.FieldIsNull(FieldUserTasks)) +} + +// UserTasksNotNil applies the NotNil predicate on the "user_tasks" field. +func UserTasksNotNil() predicate.Task { + return predicate.Task(sql.FieldNotNull(FieldUserTasks)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.Task { + return predicate.Task(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.Task { + return predicate.Task(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Task) predicate.Task { + return predicate.Task(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Task) predicate.Task { + return predicate.Task(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Task) predicate.Task { + return predicate.Task(sql.NotPredicates(p)) +} diff --git a/ent/task_create.go b/ent/task_create.go new file mode 100644 index 00000000..e1d40b27 --- /dev/null +++ b/ent/task_create.go @@ -0,0 +1,1000 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/gofrs/uuid" +) + +// TaskCreate is the builder for creating a Task entity. +type TaskCreate struct { + config + mutation *TaskMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (tc *TaskCreate) SetCreatedAt(t time.Time) *TaskCreate { + tc.mutation.SetCreatedAt(t) + return tc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. 
+func (tc *TaskCreate) SetNillableCreatedAt(t *time.Time) *TaskCreate { + if t != nil { + tc.SetCreatedAt(*t) + } + return tc +} + +// SetUpdatedAt sets the "updated_at" field. +func (tc *TaskCreate) SetUpdatedAt(t time.Time) *TaskCreate { + tc.mutation.SetUpdatedAt(t) + return tc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (tc *TaskCreate) SetNillableUpdatedAt(t *time.Time) *TaskCreate { + if t != nil { + tc.SetUpdatedAt(*t) + } + return tc +} + +// SetDeletedAt sets the "deleted_at" field. +func (tc *TaskCreate) SetDeletedAt(t time.Time) *TaskCreate { + tc.mutation.SetDeletedAt(t) + return tc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (tc *TaskCreate) SetNillableDeletedAt(t *time.Time) *TaskCreate { + if t != nil { + tc.SetDeletedAt(*t) + } + return tc +} + +// SetType sets the "type" field. +func (tc *TaskCreate) SetType(s string) *TaskCreate { + tc.mutation.SetType(s) + return tc +} + +// SetStatus sets the "status" field. +func (tc *TaskCreate) SetStatus(t task.Status) *TaskCreate { + tc.mutation.SetStatus(t) + return tc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (tc *TaskCreate) SetNillableStatus(t *task.Status) *TaskCreate { + if t != nil { + tc.SetStatus(*t) + } + return tc +} + +// SetPublicState sets the "public_state" field. +func (tc *TaskCreate) SetPublicState(tps *types.TaskPublicState) *TaskCreate { + tc.mutation.SetPublicState(tps) + return tc +} + +// SetPrivateState sets the "private_state" field. +func (tc *TaskCreate) SetPrivateState(s string) *TaskCreate { + tc.mutation.SetPrivateState(s) + return tc +} + +// SetNillablePrivateState sets the "private_state" field if the given value is not nil. +func (tc *TaskCreate) SetNillablePrivateState(s *string) *TaskCreate { + if s != nil { + tc.SetPrivateState(*s) + } + return tc +} + +// SetCorrelationID sets the "correlation_id" field. +func (tc *TaskCreate) SetCorrelationID(u uuid.UUID) *TaskCreate { + tc.mutation.SetCorrelationID(u) + return tc +} + +// SetNillableCorrelationID sets the "correlation_id" field if the given value is not nil. +func (tc *TaskCreate) SetNillableCorrelationID(u *uuid.UUID) *TaskCreate { + if u != nil { + tc.SetCorrelationID(*u) + } + return tc +} + +// SetUserTasks sets the "user_tasks" field. +func (tc *TaskCreate) SetUserTasks(i int) *TaskCreate { + tc.mutation.SetUserTasks(i) + return tc +} + +// SetNillableUserTasks sets the "user_tasks" field if the given value is not nil. +func (tc *TaskCreate) SetNillableUserTasks(i *int) *TaskCreate { + if i != nil { + tc.SetUserTasks(*i) + } + return tc +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (tc *TaskCreate) SetUserID(id int) *TaskCreate { + tc.mutation.SetUserID(id) + return tc +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (tc *TaskCreate) SetNillableUserID(id *int) *TaskCreate { + if id != nil { + tc = tc.SetUserID(*id) + } + return tc +} + +// SetUser sets the "user" edge to the User entity. +func (tc *TaskCreate) SetUser(u *User) *TaskCreate { + return tc.SetUserID(u.ID) +} + +// Mutation returns the TaskMutation object of the builder. +func (tc *TaskCreate) Mutation() *TaskMutation { + return tc.mutation +} + +// Save creates the Task in the database. 
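+//
+// Illustrative sketch, not part of the generated code. It assumes an
+// initialized *ent.Client named "client", a context "ctx", an existing
+// *ent.User "owner", and a hypothetical task type "remote_download".
+//
+//	t, err := client.Task.Create().
+//		SetType("remote_download").
+//		SetStatus(task.StatusQueued).
+//		SetCorrelationID(uuid.Must(uuid.NewV4())).
+//		SetUser(owner).
+//		Save(ctx)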
+func (tc *TaskCreate) Save(ctx context.Context) (*Task, error) { + if err := tc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, tc.sqlSave, tc.mutation, tc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (tc *TaskCreate) SaveX(ctx context.Context) *Task { + v, err := tc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (tc *TaskCreate) Exec(ctx context.Context) error { + _, err := tc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (tc *TaskCreate) ExecX(ctx context.Context) { + if err := tc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (tc *TaskCreate) defaults() error { + if _, ok := tc.mutation.CreatedAt(); !ok { + if task.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized task.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := task.DefaultCreatedAt() + tc.mutation.SetCreatedAt(v) + } + if _, ok := tc.mutation.UpdatedAt(); !ok { + if task.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized task.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := task.DefaultUpdatedAt() + tc.mutation.SetUpdatedAt(v) + } + if _, ok := tc.mutation.Status(); !ok { + v := task.DefaultStatus + tc.mutation.SetStatus(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (tc *TaskCreate) check() error { + if _, ok := tc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Task.created_at"`)} + } + if _, ok := tc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Task.updated_at"`)} + } + if _, ok := tc.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "Task.type"`)} + } + if _, ok := tc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Task.status"`)} + } + if v, ok := tc.mutation.Status(); ok { + if err := task.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Task.status": %w`, err)} + } + } + if _, ok := tc.mutation.PublicState(); !ok { + return &ValidationError{Name: "public_state", err: errors.New(`ent: missing required field "Task.public_state"`)} + } + return nil +} + +func (tc *TaskCreate) sqlSave(ctx context.Context) (*Task, error) { + if err := tc.check(); err != nil { + return nil, err + } + _node, _spec := tc.createSpec() + if err := sqlgraph.CreateNode(ctx, tc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + tc.mutation.id = &_node.ID + tc.mutation.done = true + return _node, nil +} + +func (tc *TaskCreate) createSpec() (*Task, *sqlgraph.CreateSpec) { + var ( + _node = &Task{config: tc.config} + _spec = sqlgraph.NewCreateSpec(task.Table, sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt)) + ) + + if id, ok := tc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = tc.conflict + if value, ok := tc.mutation.CreatedAt(); ok { + _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := tc.mutation.UpdatedAt(); ok { + 
_spec.SetField(task.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := tc.mutation.DeletedAt(); ok { + _spec.SetField(task.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := tc.mutation.GetType(); ok { + _spec.SetField(task.FieldType, field.TypeString, value) + _node.Type = value + } + if value, ok := tc.mutation.Status(); ok { + _spec.SetField(task.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := tc.mutation.PublicState(); ok { + _spec.SetField(task.FieldPublicState, field.TypeJSON, value) + _node.PublicState = value + } + if value, ok := tc.mutation.PrivateState(); ok { + _spec.SetField(task.FieldPrivateState, field.TypeString, value) + _node.PrivateState = value + } + if value, ok := tc.mutation.CorrelationID(); ok { + _spec.SetField(task.FieldCorrelationID, field.TypeUUID, value) + _node.CorrelationID = value + } + if nodes := tc.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: task.UserTable, + Columns: []string{task.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserTasks = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Task.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.TaskUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (tc *TaskCreate) OnConflict(opts ...sql.ConflictOption) *TaskUpsertOne { + tc.conflict = opts + return &TaskUpsertOne{ + create: tc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (tc *TaskCreate) OnConflictColumns(columns ...string) *TaskUpsertOne { + tc.conflict = append(tc.conflict, sql.ConflictColumns(columns...)) + return &TaskUpsertOne{ + create: tc, + } +} + +type ( + // TaskUpsertOne is the builder for "upsert"-ing + // one Task node. + TaskUpsertOne struct { + create *TaskCreate + } + + // TaskUpsert is the "OnConflict" setter. + TaskUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *TaskUpsert) SetUpdatedAt(v time.Time) *TaskUpsert { + u.Set(task.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *TaskUpsert) UpdateUpdatedAt() *TaskUpsert { + u.SetExcluded(task.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *TaskUpsert) SetDeletedAt(v time.Time) *TaskUpsert { + u.Set(task.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *TaskUpsert) UpdateDeletedAt() *TaskUpsert { + u.SetExcluded(task.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. 
+func (u *TaskUpsert) ClearDeletedAt() *TaskUpsert { + u.SetNull(task.FieldDeletedAt) + return u +} + +// SetType sets the "type" field. +func (u *TaskUpsert) SetType(v string) *TaskUpsert { + u.Set(task.FieldType, v) + return u +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *TaskUpsert) UpdateType() *TaskUpsert { + u.SetExcluded(task.FieldType) + return u +} + +// SetStatus sets the "status" field. +func (u *TaskUpsert) SetStatus(v task.Status) *TaskUpsert { + u.Set(task.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *TaskUpsert) UpdateStatus() *TaskUpsert { + u.SetExcluded(task.FieldStatus) + return u +} + +// SetPublicState sets the "public_state" field. +func (u *TaskUpsert) SetPublicState(v *types.TaskPublicState) *TaskUpsert { + u.Set(task.FieldPublicState, v) + return u +} + +// UpdatePublicState sets the "public_state" field to the value that was provided on create. +func (u *TaskUpsert) UpdatePublicState() *TaskUpsert { + u.SetExcluded(task.FieldPublicState) + return u +} + +// SetPrivateState sets the "private_state" field. +func (u *TaskUpsert) SetPrivateState(v string) *TaskUpsert { + u.Set(task.FieldPrivateState, v) + return u +} + +// UpdatePrivateState sets the "private_state" field to the value that was provided on create. +func (u *TaskUpsert) UpdatePrivateState() *TaskUpsert { + u.SetExcluded(task.FieldPrivateState) + return u +} + +// ClearPrivateState clears the value of the "private_state" field. +func (u *TaskUpsert) ClearPrivateState() *TaskUpsert { + u.SetNull(task.FieldPrivateState) + return u +} + +// SetUserTasks sets the "user_tasks" field. +func (u *TaskUpsert) SetUserTasks(v int) *TaskUpsert { + u.Set(task.FieldUserTasks, v) + return u +} + +// UpdateUserTasks sets the "user_tasks" field to the value that was provided on create. +func (u *TaskUpsert) UpdateUserTasks() *TaskUpsert { + u.SetExcluded(task.FieldUserTasks) + return u +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (u *TaskUpsert) ClearUserTasks() *TaskUpsert { + u.SetNull(task.FieldUserTasks) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *TaskUpsertOne) UpdateNewValues() *TaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(task.FieldCreatedAt) + } + if _, exists := u.create.mutation.CorrelationID(); exists { + s.SetIgnore(task.FieldCorrelationID) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *TaskUpsertOne) Ignore() *TaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *TaskUpsertOne) DoNothing() *TaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. 
See the TaskCreate.OnConflict +// documentation for more info. +func (u *TaskUpsertOne) Update(set func(*TaskUpsert)) *TaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&TaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *TaskUpsertOne) SetUpdatedAt(v time.Time) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateUpdatedAt() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *TaskUpsertOne) SetDeletedAt(v time.Time) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateDeletedAt() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *TaskUpsertOne) ClearDeletedAt() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *TaskUpsertOne) SetType(v string) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateType() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateType() + }) +} + +// SetStatus sets the "status" field. +func (u *TaskUpsertOne) SetStatus(v task.Status) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateStatus() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateStatus() + }) +} + +// SetPublicState sets the "public_state" field. +func (u *TaskUpsertOne) SetPublicState(v *types.TaskPublicState) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetPublicState(v) + }) +} + +// UpdatePublicState sets the "public_state" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdatePublicState() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdatePublicState() + }) +} + +// SetPrivateState sets the "private_state" field. +func (u *TaskUpsertOne) SetPrivateState(v string) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetPrivateState(v) + }) +} + +// UpdatePrivateState sets the "private_state" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdatePrivateState() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdatePrivateState() + }) +} + +// ClearPrivateState clears the value of the "private_state" field. +func (u *TaskUpsertOne) ClearPrivateState() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.ClearPrivateState() + }) +} + +// SetUserTasks sets the "user_tasks" field. +func (u *TaskUpsertOne) SetUserTasks(v int) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetUserTasks(v) + }) +} + +// UpdateUserTasks sets the "user_tasks" field to the value that was provided on create. 
+func (u *TaskUpsertOne) UpdateUserTasks() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateUserTasks() + }) +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (u *TaskUpsertOne) ClearUserTasks() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.ClearUserTasks() + }) +} + +// Exec executes the query. +func (u *TaskUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for TaskCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *TaskUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *TaskUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *TaskUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *TaskCreate) SetRawID(t int) *TaskCreate { + m.mutation.SetRawID(t) + return m +} + +// TaskCreateBulk is the builder for creating many Task entities in bulk. +type TaskCreateBulk struct { + config + err error + builders []*TaskCreate + conflict []sql.ConflictOption +} + +// Save creates the Task entities in the database. +func (tcb *TaskCreateBulk) Save(ctx context.Context) ([]*Task, error) { + if tcb.err != nil { + return nil, tcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(tcb.builders)) + nodes := make([]*Task, len(tcb.builders)) + mutators := make([]Mutator, len(tcb.builders)) + for i := range tcb.builders { + func(i int, root context.Context) { + builder := tcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*TaskMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, tcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = tcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, tcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, tcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (tcb *TaskCreateBulk) SaveX(ctx context.Context) []*Task { + v, err := tcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (tcb *TaskCreateBulk) Exec(ctx context.Context) error { + _, err := tcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (tcb *TaskCreateBulk) ExecX(ctx context.Context) { + if err := tcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.Task.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.TaskUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (tcb *TaskCreateBulk) OnConflict(opts ...sql.ConflictOption) *TaskUpsertBulk { + tcb.conflict = opts + return &TaskUpsertBulk{ + create: tcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (tcb *TaskCreateBulk) OnConflictColumns(columns ...string) *TaskUpsertBulk { + tcb.conflict = append(tcb.conflict, sql.ConflictColumns(columns...)) + return &TaskUpsertBulk{ + create: tcb, + } +} + +// TaskUpsertBulk is the builder for "upsert"-ing +// a bulk of Task nodes. +type TaskUpsertBulk struct { + create *TaskCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *TaskUpsertBulk) UpdateNewValues() *TaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(task.FieldCreatedAt) + } + if _, exists := b.mutation.CorrelationID(); exists { + s.SetIgnore(task.FieldCorrelationID) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.Task.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *TaskUpsertBulk) Ignore() *TaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *TaskUpsertBulk) DoNothing() *TaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the TaskCreateBulk.OnConflict +// documentation for more info. +func (u *TaskUpsertBulk) Update(set func(*TaskUpsert)) *TaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&TaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *TaskUpsertBulk) SetUpdatedAt(v time.Time) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateUpdatedAt() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. 
+func (u *TaskUpsertBulk) SetDeletedAt(v time.Time) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateDeletedAt() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *TaskUpsertBulk) ClearDeletedAt() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.ClearDeletedAt() + }) +} + +// SetType sets the "type" field. +func (u *TaskUpsertBulk) SetType(v string) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetType(v) + }) +} + +// UpdateType sets the "type" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateType() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateType() + }) +} + +// SetStatus sets the "status" field. +func (u *TaskUpsertBulk) SetStatus(v task.Status) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateStatus() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateStatus() + }) +} + +// SetPublicState sets the "public_state" field. +func (u *TaskUpsertBulk) SetPublicState(v *types.TaskPublicState) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetPublicState(v) + }) +} + +// UpdatePublicState sets the "public_state" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdatePublicState() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdatePublicState() + }) +} + +// SetPrivateState sets the "private_state" field. +func (u *TaskUpsertBulk) SetPrivateState(v string) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetPrivateState(v) + }) +} + +// UpdatePrivateState sets the "private_state" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdatePrivateState() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdatePrivateState() + }) +} + +// ClearPrivateState clears the value of the "private_state" field. +func (u *TaskUpsertBulk) ClearPrivateState() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.ClearPrivateState() + }) +} + +// SetUserTasks sets the "user_tasks" field. +func (u *TaskUpsertBulk) SetUserTasks(v int) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetUserTasks(v) + }) +} + +// UpdateUserTasks sets the "user_tasks" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateUserTasks() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateUserTasks() + }) +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (u *TaskUpsertBulk) ClearUserTasks() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.ClearUserTasks() + }) +} + +// Exec executes the query. +func (u *TaskUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the TaskCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for TaskCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *TaskUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/task_delete.go b/ent/task_delete.go new file mode 100644 index 00000000..9425ec2c --- /dev/null +++ b/ent/task_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/task" +) + +// TaskDelete is the builder for deleting a Task entity. +type TaskDelete struct { + config + hooks []Hook + mutation *TaskMutation +} + +// Where appends a list predicates to the TaskDelete builder. +func (td *TaskDelete) Where(ps ...predicate.Task) *TaskDelete { + td.mutation.Where(ps...) + return td +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (td *TaskDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, td.sqlExec, td.mutation, td.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (td *TaskDelete) ExecX(ctx context.Context) int { + n, err := td.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (td *TaskDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(task.Table, sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt)) + if ps := td.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, td.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + td.mutation.done = true + return affected, err +} + +// TaskDeleteOne is the builder for deleting a single Task entity. +type TaskDeleteOne struct { + td *TaskDelete +} + +// Where appends a list predicates to the TaskDelete builder. +func (tdo *TaskDeleteOne) Where(ps ...predicate.Task) *TaskDeleteOne { + tdo.td.mutation.Where(ps...) + return tdo +} + +// Exec executes the deletion query. +func (tdo *TaskDeleteOne) Exec(ctx context.Context) error { + n, err := tdo.td.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{task.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (tdo *TaskDeleteOne) ExecX(ctx context.Context) { + if err := tdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/task_query.go b/ent/task_query.go new file mode 100644 index 00000000..39155fd1 --- /dev/null +++ b/ent/task_query.go @@ -0,0 +1,605 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// TaskQuery is the builder for querying Task entities. +type TaskQuery struct { + config + ctx *QueryContext + order []task.OrderOption + inters []Interceptor + predicates []predicate.Task + withUser *UserQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the TaskQuery builder. +func (tq *TaskQuery) Where(ps ...predicate.Task) *TaskQuery { + tq.predicates = append(tq.predicates, ps...) 
+ return tq +} + +// Limit the number of records to be returned by this query. +func (tq *TaskQuery) Limit(limit int) *TaskQuery { + tq.ctx.Limit = &limit + return tq +} + +// Offset to start from. +func (tq *TaskQuery) Offset(offset int) *TaskQuery { + tq.ctx.Offset = &offset + return tq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (tq *TaskQuery) Unique(unique bool) *TaskQuery { + tq.ctx.Unique = &unique + return tq +} + +// Order specifies how the records should be ordered. +func (tq *TaskQuery) Order(o ...task.OrderOption) *TaskQuery { + tq.order = append(tq.order, o...) + return tq +} + +// QueryUser chains the current query on the "user" edge. +func (tq *TaskQuery) QueryUser() *UserQuery { + query := (&UserClient{config: tq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := tq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := tq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(task.Table, task.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, task.UserTable, task.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(tq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Task entity from the query. +// Returns a *NotFoundError when no Task was found. +func (tq *TaskQuery) First(ctx context.Context) (*Task, error) { + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{task.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (tq *TaskQuery) FirstX(ctx context.Context) *Task { + node, err := tq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Task ID from the query. +// Returns a *NotFoundError when no Task ID was found. +func (tq *TaskQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{task.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (tq *TaskQuery) FirstIDX(ctx context.Context) int { + id, err := tq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Task entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Task entity is found. +// Returns a *NotFoundError when no Task entities are found. +func (tq *TaskQuery) Only(ctx context.Context) (*Task, error) { + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{task.Label} + default: + return nil, &NotSingularError{task.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (tq *TaskQuery) OnlyX(ctx context.Context) *Task { + node, err := tq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Task ID in the query. +// Returns a *NotSingularError when more than one Task ID is found. 
+// Returns a *NotFoundError when no entities are found. +func (tq *TaskQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{task.Label} + default: + err = &NotSingularError{task.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (tq *TaskQuery) OnlyIDX(ctx context.Context) int { + id, err := tq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Tasks. +func (tq *TaskQuery) All(ctx context.Context) ([]*Task, error) { + ctx = setContextOp(ctx, tq.ctx, "All") + if err := tq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Task, *TaskQuery]() + return withInterceptors[[]*Task](ctx, tq, qr, tq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (tq *TaskQuery) AllX(ctx context.Context) []*Task { + nodes, err := tq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Task IDs. +func (tq *TaskQuery) IDs(ctx context.Context) (ids []int, err error) { + if tq.ctx.Unique == nil && tq.path != nil { + tq.Unique(true) + } + ctx = setContextOp(ctx, tq.ctx, "IDs") + if err = tq.Select(task.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (tq *TaskQuery) IDsX(ctx context.Context) []int { + ids, err := tq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (tq *TaskQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, tq.ctx, "Count") + if err := tq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, tq, querierCount[*TaskQuery](), tq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (tq *TaskQuery) CountX(ctx context.Context) int { + count, err := tq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (tq *TaskQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, tq.ctx, "Exist") + switch _, err := tq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (tq *TaskQuery) ExistX(ctx context.Context) bool { + exist, err := tq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the TaskQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (tq *TaskQuery) Clone() *TaskQuery { + if tq == nil { + return nil + } + return &TaskQuery{ + config: tq.config, + ctx: tq.ctx.Clone(), + order: append([]task.OrderOption{}, tq.order...), + inters: append([]Interceptor{}, tq.inters...), + predicates: append([]predicate.Task{}, tq.predicates...), + withUser: tq.withUser.Clone(), + // clone intermediate query. + sql: tq.sql.Clone(), + path: tq.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. 
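+//
+// Illustrative sketch, not part of the generated code (assumes an initialized
+// *ent.Client named "client" and a context "ctx"): eager-load the owning user
+// of every queued task.
+//
+//	tasks, err := client.Task.Query().
+//		Where(task.StatusEQ(task.StatusQueued)).
+//		WithUser().
+//		All(ctx)
+//	for _, t := range tasks {
+//		_ = t.Edges.User // populated because WithUser was set
+//	}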
+func (tq *TaskQuery) WithUser(opts ...func(*UserQuery)) *TaskQuery { + query := (&UserClient{config: tq.config}).Query() + for _, opt := range opts { + opt(query) + } + tq.withUser = query + return tq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Task.Query(). +// GroupBy(task.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (tq *TaskQuery) GroupBy(field string, fields ...string) *TaskGroupBy { + tq.ctx.Fields = append([]string{field}, fields...) + grbuild := &TaskGroupBy{build: tq} + grbuild.flds = &tq.ctx.Fields + grbuild.label = task.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.Task.Query(). +// Select(task.FieldCreatedAt). +// Scan(ctx, &v) +func (tq *TaskQuery) Select(fields ...string) *TaskSelect { + tq.ctx.Fields = append(tq.ctx.Fields, fields...) + sbuild := &TaskSelect{TaskQuery: tq} + sbuild.label = task.Label + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a TaskSelect configured with the given aggregations. +func (tq *TaskQuery) Aggregate(fns ...AggregateFunc) *TaskSelect { + return tq.Select().Aggregate(fns...) +} + +func (tq *TaskQuery) prepareQuery(ctx context.Context) error { + for _, inter := range tq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, tq); err != nil { + return err + } + } + } + for _, f := range tq.ctx.Fields { + if !task.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if tq.path != nil { + prev, err := tq.path(ctx) + if err != nil { + return err + } + tq.sql = prev + } + return nil +} + +func (tq *TaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Task, error) { + var ( + nodes = []*Task{} + _spec = tq.querySpec() + loadedTypes = [1]bool{ + tq.withUser != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Task).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Task{config: tq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := tq.withUser; query != nil { + if err := tq.loadUser(ctx, query, nodes, nil, + func(n *Task, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (tq *TaskQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Task, init func(*Task), assign func(*Task, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Task) + for i := range nodes { + fk := nodes[i].UserTasks + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { 
+ return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_tasks" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (tq *TaskQuery) sqlCount(ctx context.Context) (int, error) { + _spec := tq.querySpec() + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, tq.driver, _spec) +} + +func (tq *TaskQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(task.Table, task.Columns, sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt)) + _spec.From = tq.sql + if unique := tq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if tq.path != nil { + _spec.Unique = true + } + if fields := tq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, task.FieldID) + for i := range fields { + if fields[i] != task.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if tq.withUser != nil { + _spec.Node.AddColumnOnce(task.FieldUserTasks) + } + } + if ps := tq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := tq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := tq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := tq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (tq *TaskQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(tq.driver.Dialect()) + t1 := builder.Table(task.Table) + columns := tq.ctx.Fields + if len(columns) == 0 { + columns = task.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if tq.sql != nil { + selector = tq.sql + selector.Select(selector.Columns(columns...)...) + } + if tq.ctx.Unique != nil && *tq.ctx.Unique { + selector.Distinct() + } + for _, p := range tq.predicates { + p(selector) + } + for _, p := range tq.order { + p(selector) + } + if offset := tq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := tq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// TaskGroupBy is the group-by builder for Task entities. +type TaskGroupBy struct { + selector + build *TaskQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (tgb *TaskGroupBy) Aggregate(fns ...AggregateFunc) *TaskGroupBy { + tgb.fns = append(tgb.fns, fns...) + return tgb +} + +// Scan applies the selector query and scans the result into the given value. 
+func (tgb *TaskGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") + if err := tgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*TaskQuery, *TaskGroupBy](ctx, tgb.build, tgb, tgb.build.inters, v) +} + +func (tgb *TaskGroupBy) sqlScan(ctx context.Context, root *TaskQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(tgb.fns)) + for _, fn := range tgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*tgb.flds)+len(tgb.fns)) + for _, f := range *tgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*tgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := tgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// TaskSelect is the builder for selecting fields of Task entities. +type TaskSelect struct { + *TaskQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ts *TaskSelect) Aggregate(fns ...AggregateFunc) *TaskSelect { + ts.fns = append(ts.fns, fns...) + return ts +} + +// Scan applies the selector query and scans the result into the given value. +func (ts *TaskSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ts.ctx, "Select") + if err := ts.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*TaskQuery, *TaskSelect](ctx, ts.TaskQuery, ts, ts.inters, v) +} + +func (ts *TaskSelect) sqlScan(ctx context.Context, root *TaskQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ts.fns)) + for _, fn := range ts.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ts.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ts.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/task_update.go b/ent/task_update.go new file mode 100644 index 00000000..9bf5452f --- /dev/null +++ b/ent/task_update.go @@ -0,0 +1,596 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// TaskUpdate is the builder for updating Task entities. +type TaskUpdate struct { + config + hooks []Hook + mutation *TaskMutation +} + +// Where appends a list predicates to the TaskUpdate builder. +func (tu *TaskUpdate) Where(ps ...predicate.Task) *TaskUpdate { + tu.mutation.Where(ps...) + return tu +} + +// SetUpdatedAt sets the "updated_at" field. +func (tu *TaskUpdate) SetUpdatedAt(t time.Time) *TaskUpdate { + tu.mutation.SetUpdatedAt(t) + return tu +} + +// SetDeletedAt sets the "deleted_at" field. 
+func (tu *TaskUpdate) SetDeletedAt(t time.Time) *TaskUpdate { + tu.mutation.SetDeletedAt(t) + return tu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableDeletedAt(t *time.Time) *TaskUpdate { + if t != nil { + tu.SetDeletedAt(*t) + } + return tu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (tu *TaskUpdate) ClearDeletedAt() *TaskUpdate { + tu.mutation.ClearDeletedAt() + return tu +} + +// SetType sets the "type" field. +func (tu *TaskUpdate) SetType(s string) *TaskUpdate { + tu.mutation.SetType(s) + return tu +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableType(s *string) *TaskUpdate { + if s != nil { + tu.SetType(*s) + } + return tu +} + +// SetStatus sets the "status" field. +func (tu *TaskUpdate) SetStatus(t task.Status) *TaskUpdate { + tu.mutation.SetStatus(t) + return tu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableStatus(t *task.Status) *TaskUpdate { + if t != nil { + tu.SetStatus(*t) + } + return tu +} + +// SetPublicState sets the "public_state" field. +func (tu *TaskUpdate) SetPublicState(tps *types.TaskPublicState) *TaskUpdate { + tu.mutation.SetPublicState(tps) + return tu +} + +// SetPrivateState sets the "private_state" field. +func (tu *TaskUpdate) SetPrivateState(s string) *TaskUpdate { + tu.mutation.SetPrivateState(s) + return tu +} + +// SetNillablePrivateState sets the "private_state" field if the given value is not nil. +func (tu *TaskUpdate) SetNillablePrivateState(s *string) *TaskUpdate { + if s != nil { + tu.SetPrivateState(*s) + } + return tu +} + +// ClearPrivateState clears the value of the "private_state" field. +func (tu *TaskUpdate) ClearPrivateState() *TaskUpdate { + tu.mutation.ClearPrivateState() + return tu +} + +// SetUserTasks sets the "user_tasks" field. +func (tu *TaskUpdate) SetUserTasks(i int) *TaskUpdate { + tu.mutation.SetUserTasks(i) + return tu +} + +// SetNillableUserTasks sets the "user_tasks" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableUserTasks(i *int) *TaskUpdate { + if i != nil { + tu.SetUserTasks(*i) + } + return tu +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (tu *TaskUpdate) ClearUserTasks() *TaskUpdate { + tu.mutation.ClearUserTasks() + return tu +} + +// SetUserID sets the "user" edge to the User entity by ID. +func (tu *TaskUpdate) SetUserID(id int) *TaskUpdate { + tu.mutation.SetUserID(id) + return tu +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (tu *TaskUpdate) SetNillableUserID(id *int) *TaskUpdate { + if id != nil { + tu = tu.SetUserID(*id) + } + return tu +} + +// SetUser sets the "user" edge to the User entity. +func (tu *TaskUpdate) SetUser(u *User) *TaskUpdate { + return tu.SetUserID(u.ID) +} + +// Mutation returns the TaskMutation object of the builder. +func (tu *TaskUpdate) Mutation() *TaskMutation { + return tu.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (tu *TaskUpdate) ClearUser() *TaskUpdate { + tu.mutation.ClearUser() + return tu +} + +// Save executes the query and returns the number of nodes affected by the update operation. 
+func (tu *TaskUpdate) Save(ctx context.Context) (int, error) { + if err := tu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, tu.sqlSave, tu.mutation, tu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (tu *TaskUpdate) SaveX(ctx context.Context) int { + affected, err := tu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (tu *TaskUpdate) Exec(ctx context.Context) error { + _, err := tu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (tu *TaskUpdate) ExecX(ctx context.Context) { + if err := tu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (tu *TaskUpdate) defaults() error { + if _, ok := tu.mutation.UpdatedAt(); !ok { + if task.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized task.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := task.UpdateDefaultUpdatedAt() + tu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (tu *TaskUpdate) check() error { + if v, ok := tu.mutation.Status(); ok { + if err := task.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Task.status": %w`, err)} + } + } + return nil +} + +func (tu *TaskUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := tu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(task.Table, task.Columns, sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt)) + if ps := tu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := tu.mutation.UpdatedAt(); ok { + _spec.SetField(task.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := tu.mutation.DeletedAt(); ok { + _spec.SetField(task.FieldDeletedAt, field.TypeTime, value) + } + if tu.mutation.DeletedAtCleared() { + _spec.ClearField(task.FieldDeletedAt, field.TypeTime) + } + if value, ok := tu.mutation.GetType(); ok { + _spec.SetField(task.FieldType, field.TypeString, value) + } + if value, ok := tu.mutation.Status(); ok { + _spec.SetField(task.FieldStatus, field.TypeEnum, value) + } + if value, ok := tu.mutation.PublicState(); ok { + _spec.SetField(task.FieldPublicState, field.TypeJSON, value) + } + if value, ok := tu.mutation.PrivateState(); ok { + _spec.SetField(task.FieldPrivateState, field.TypeString, value) + } + if tu.mutation.PrivateStateCleared() { + _spec.ClearField(task.FieldPrivateState, field.TypeString) + } + if tu.mutation.CorrelationIDCleared() { + _spec.ClearField(task.FieldCorrelationID, field.TypeUUID) + } + if tu.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: task.UserTable, + Columns: []string{task.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := tu.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: task.UserTable, + Columns: []string{task.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = 
append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, tu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{task.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + tu.mutation.done = true + return n, nil +} + +// TaskUpdateOne is the builder for updating a single Task entity. +type TaskUpdateOne struct { + config + fields []string + hooks []Hook + mutation *TaskMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (tuo *TaskUpdateOne) SetUpdatedAt(t time.Time) *TaskUpdateOne { + tuo.mutation.SetUpdatedAt(t) + return tuo +} + +// SetDeletedAt sets the "deleted_at" field. +func (tuo *TaskUpdateOne) SetDeletedAt(t time.Time) *TaskUpdateOne { + tuo.mutation.SetDeletedAt(t) + return tuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableDeletedAt(t *time.Time) *TaskUpdateOne { + if t != nil { + tuo.SetDeletedAt(*t) + } + return tuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (tuo *TaskUpdateOne) ClearDeletedAt() *TaskUpdateOne { + tuo.mutation.ClearDeletedAt() + return tuo +} + +// SetType sets the "type" field. +func (tuo *TaskUpdateOne) SetType(s string) *TaskUpdateOne { + tuo.mutation.SetType(s) + return tuo +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableType(s *string) *TaskUpdateOne { + if s != nil { + tuo.SetType(*s) + } + return tuo +} + +// SetStatus sets the "status" field. +func (tuo *TaskUpdateOne) SetStatus(t task.Status) *TaskUpdateOne { + tuo.mutation.SetStatus(t) + return tuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableStatus(t *task.Status) *TaskUpdateOne { + if t != nil { + tuo.SetStatus(*t) + } + return tuo +} + +// SetPublicState sets the "public_state" field. +func (tuo *TaskUpdateOne) SetPublicState(tps *types.TaskPublicState) *TaskUpdateOne { + tuo.mutation.SetPublicState(tps) + return tuo +} + +// SetPrivateState sets the "private_state" field. +func (tuo *TaskUpdateOne) SetPrivateState(s string) *TaskUpdateOne { + tuo.mutation.SetPrivateState(s) + return tuo +} + +// SetNillablePrivateState sets the "private_state" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillablePrivateState(s *string) *TaskUpdateOne { + if s != nil { + tuo.SetPrivateState(*s) + } + return tuo +} + +// ClearPrivateState clears the value of the "private_state" field. +func (tuo *TaskUpdateOne) ClearPrivateState() *TaskUpdateOne { + tuo.mutation.ClearPrivateState() + return tuo +} + +// SetUserTasks sets the "user_tasks" field. +func (tuo *TaskUpdateOne) SetUserTasks(i int) *TaskUpdateOne { + tuo.mutation.SetUserTasks(i) + return tuo +} + +// SetNillableUserTasks sets the "user_tasks" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableUserTasks(i *int) *TaskUpdateOne { + if i != nil { + tuo.SetUserTasks(*i) + } + return tuo +} + +// ClearUserTasks clears the value of the "user_tasks" field. +func (tuo *TaskUpdateOne) ClearUserTasks() *TaskUpdateOne { + tuo.mutation.ClearUserTasks() + return tuo +} + +// SetUserID sets the "user" edge to the User entity by ID. 
+func (tuo *TaskUpdateOne) SetUserID(id int) *TaskUpdateOne { + tuo.mutation.SetUserID(id) + return tuo +} + +// SetNillableUserID sets the "user" edge to the User entity by ID if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableUserID(id *int) *TaskUpdateOne { + if id != nil { + tuo = tuo.SetUserID(*id) + } + return tuo +} + +// SetUser sets the "user" edge to the User entity. +func (tuo *TaskUpdateOne) SetUser(u *User) *TaskUpdateOne { + return tuo.SetUserID(u.ID) +} + +// Mutation returns the TaskMutation object of the builder. +func (tuo *TaskUpdateOne) Mutation() *TaskMutation { + return tuo.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (tuo *TaskUpdateOne) ClearUser() *TaskUpdateOne { + tuo.mutation.ClearUser() + return tuo +} + +// Where appends a list predicates to the TaskUpdate builder. +func (tuo *TaskUpdateOne) Where(ps ...predicate.Task) *TaskUpdateOne { + tuo.mutation.Where(ps...) + return tuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (tuo *TaskUpdateOne) Select(field string, fields ...string) *TaskUpdateOne { + tuo.fields = append([]string{field}, fields...) + return tuo +} + +// Save executes the query and returns the updated Task entity. +func (tuo *TaskUpdateOne) Save(ctx context.Context) (*Task, error) { + if err := tuo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, tuo.sqlSave, tuo.mutation, tuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (tuo *TaskUpdateOne) SaveX(ctx context.Context) *Task { + node, err := tuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (tuo *TaskUpdateOne) Exec(ctx context.Context) error { + _, err := tuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (tuo *TaskUpdateOne) ExecX(ctx context.Context) { + if err := tuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (tuo *TaskUpdateOne) defaults() error { + if _, ok := tuo.mutation.UpdatedAt(); !ok { + if task.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized task.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := task.UpdateDefaultUpdatedAt() + tuo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (tuo *TaskUpdateOne) check() error { + if v, ok := tuo.mutation.Status(); ok { + if err := task.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Task.status": %w`, err)} + } + } + return nil +} + +func (tuo *TaskUpdateOne) sqlSave(ctx context.Context) (_node *Task, err error) { + if err := tuo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(task.Table, task.Columns, sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt)) + id, ok := tuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Task.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := tuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, task.FieldID) + for _, f := range fields { + if !task.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != task.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := tuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := tuo.mutation.UpdatedAt(); ok { + _spec.SetField(task.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := tuo.mutation.DeletedAt(); ok { + _spec.SetField(task.FieldDeletedAt, field.TypeTime, value) + } + if tuo.mutation.DeletedAtCleared() { + _spec.ClearField(task.FieldDeletedAt, field.TypeTime) + } + if value, ok := tuo.mutation.GetType(); ok { + _spec.SetField(task.FieldType, field.TypeString, value) + } + if value, ok := tuo.mutation.Status(); ok { + _spec.SetField(task.FieldStatus, field.TypeEnum, value) + } + if value, ok := tuo.mutation.PublicState(); ok { + _spec.SetField(task.FieldPublicState, field.TypeJSON, value) + } + if value, ok := tuo.mutation.PrivateState(); ok { + _spec.SetField(task.FieldPrivateState, field.TypeString, value) + } + if tuo.mutation.PrivateStateCleared() { + _spec.ClearField(task.FieldPrivateState, field.TypeString) + } + if tuo.mutation.CorrelationIDCleared() { + _spec.ClearField(task.FieldCorrelationID, field.TypeUUID) + } + if tuo.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: task.UserTable, + Columns: []string{task.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := tuo.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: task.UserTable, + Columns: []string{task.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Task{config: tuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, tuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{task.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + tuo.mutation.done = true + return _node, nil +} diff --git a/ent/templates/createhelper.tmpl 
b/ent/templates/createhelper.tmpl new file mode 100644 index 00000000..ac4ed368 --- /dev/null +++ b/ent/templates/createhelper.tmpl @@ -0,0 +1,30 @@ +{{/* The line below tells Intellij/GoLand to enable the autocompletion based *gen.Type type. */}} +{{/* gotype: entgo.io/ent/entc/gen.Type */}} + +{{ define "create/additional/createhelper" }} + +{{/* A template that adds the "GoString" method to all generated models on the same file they are defined. */}} + +{{ if $.HasOneFieldID }} + +func (m *{{ $.Name }}Create) SetRawID(t {{ $.ID.Type }}) *{{ $.Name }}Create { + m.mutation.SetRawID(t) + return m +} + +{{ end }} + +{{ end }} + +{{ define "dialect/sql/create/spec/createhelper" }} +{{ $receiver := $.Scope.Receiver }} +{{ $mutation := print $receiver ".mutation" }} +{{- if not $.HasCompositeID}} + if id, ok := {{ $mutation }}.{{ $.ID.MutationGet }}(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = {{ if and $.ID.Type.ValueScanner (not $.ID.Type.RType.IsPtr) }}&{{ end }}id64 + } +{{- end }} + +{{ end }} \ No newline at end of file diff --git a/ent/templates/edgehelper.tmpl b/ent/templates/edgehelper.tmpl new file mode 100644 index 00000000..87a7d2f0 --- /dev/null +++ b/ent/templates/edgehelper.tmpl @@ -0,0 +1,20 @@ +{{/* The line below tells Intellij/GoLand to enable the autocompletion based *gen.Type type. */}} +{{/* gotype: entgo.io/ent/entc/gen.Type */}} + +{{ define "model/additional/edgehelper" }} + +{{/* A template that adds the "GoString" method to all generated models on the same file they are defined. */}} + +{{- with $.Edges }} + +{{- range $i, $e := . }} + // Set{{ $e.StructField }} manually set the edge as loaded state. + func (e *{{ $.Name }}) Set{{ $e.StructField }}(v {{ if not $e.Unique }}[]{{ end }}*{{ $e.Type.Name }}) { + e.Edges.{{ $e.StructField }} = v + e.Edges.loadedTypes[{{ $i }}] = true + } +{{- end }} +{{- end }} + + +{{ end }} \ No newline at end of file diff --git a/ent/templates/mutationhelper.tmpl b/ent/templates/mutationhelper.tmpl new file mode 100644 index 00000000..a94f4f5c --- /dev/null +++ b/ent/templates/mutationhelper.tmpl @@ -0,0 +1,20 @@ +{{/* The line below tells Intellij/GoLand to enable the autocompletion based on the *gen.Graph type. */}} +{{/* gotype: entgo.io/ent/entc/gen.Graph */}} + +{{ define "mutationhelper" }} + +{{ $pkg := base $.Config.Package }} +{{ template "header" $ }} + +{{ range $n := $.Nodes }} +// SetUpdatedAt sets the "updated_at" field. +{{ with $n.HasOneFieldID }} + +func (m *{{ $n.Name }}Mutation) SetRawID(t {{ $n.ID.Type }}) { + m.id = &t +} + +{{ end }} +{{ end }} + +{{ end }} \ No newline at end of file diff --git a/ent/tx.go b/ent/tx.go new file mode 100644 index 00000000..98f9b540 --- /dev/null +++ b/ent/tx.go @@ -0,0 +1,272 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + stdsql "database/sql" + "fmt" + "sync" + + "entgo.io/ent/dialect" +) + +// Tx is a transactional client that is created by calling Client.Tx(). +type Tx struct { + config + // DavAccount is the client for interacting with the DavAccount builders. + DavAccount *DavAccountClient + // DirectLink is the client for interacting with the DirectLink builders. + DirectLink *DirectLinkClient + // Entity is the client for interacting with the Entity builders. + Entity *EntityClient + // File is the client for interacting with the File builders. + File *FileClient + // Group is the client for interacting with the Group builders. + Group *GroupClient + // Metadata is the client for interacting with the Metadata builders. 
+ Metadata *MetadataClient + // Node is the client for interacting with the Node builders. + Node *NodeClient + // Passkey is the client for interacting with the Passkey builders. + Passkey *PasskeyClient + // Setting is the client for interacting with the Setting builders. + Setting *SettingClient + // Share is the client for interacting with the Share builders. + Share *ShareClient + // StoragePolicy is the client for interacting with the StoragePolicy builders. + StoragePolicy *StoragePolicyClient + // Task is the client for interacting with the Task builders. + Task *TaskClient + // User is the client for interacting with the User builders. + User *UserClient + + // lazily loaded. + client *Client + clientOnce sync.Once + // ctx lives for the life of the transaction. It is + // the same context used by the underlying connection. + ctx context.Context +} + +type ( + // Committer is the interface that wraps the Commit method. + Committer interface { + Commit(context.Context, *Tx) error + } + + // The CommitFunc type is an adapter to allow the use of ordinary + // function as a Committer. If f is a function with the appropriate + // signature, CommitFunc(f) is a Committer that calls f. + CommitFunc func(context.Context, *Tx) error + + // CommitHook defines the "commit middleware". A function that gets a Committer + // and returns a Committer. For example: + // + // hook := func(next ent.Committer) ent.Committer { + // return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error { + // // Do some stuff before. + // if err := next.Commit(ctx, tx); err != nil { + // return err + // } + // // Do some stuff after. + // return nil + // }) + // } + // + CommitHook func(Committer) Committer +) + +// Commit calls f(ctx, m). +func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error { + return f(ctx, tx) +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + txDriver := tx.config.driver.(*txDriver) + var fn Committer = CommitFunc(func(context.Context, *Tx) error { + return txDriver.tx.Commit() + }) + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() + for i := len(hooks) - 1; i >= 0; i-- { + fn = hooks[i](fn) + } + return fn.Commit(tx.ctx, tx) +} + +// OnCommit adds a hook to call on commit. +func (tx *Tx) OnCommit(f CommitHook) { + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() +} + +type ( + // Rollbacker is the interface that wraps the Rollback method. + Rollbacker interface { + Rollback(context.Context, *Tx) error + } + + // The RollbackFunc type is an adapter to allow the use of ordinary + // function as a Rollbacker. If f is a function with the appropriate + // signature, RollbackFunc(f) is a Rollbacker that calls f. + RollbackFunc func(context.Context, *Tx) error + + // RollbackHook defines the "rollback middleware". A function that gets a Rollbacker + // and returns a Rollbacker. For example: + // + // hook := func(next ent.Rollbacker) ent.Rollbacker { + // return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error { + // // Do some stuff before. + // if err := next.Rollback(ctx, tx); err != nil { + // return err + // } + // // Do some stuff after. + // return nil + // }) + // } + // + RollbackHook func(Rollbacker) Rollbacker +) + +// Rollback calls f(ctx, m). +func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error { + return f(ctx, tx) +} + +// Rollback rollbacks the transaction. 
+func (tx *Tx) Rollback() error { + txDriver := tx.config.driver.(*txDriver) + var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { + return txDriver.tx.Rollback() + }) + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() + for i := len(hooks) - 1; i >= 0; i-- { + fn = hooks[i](fn) + } + return fn.Rollback(tx.ctx, tx) +} + +// OnRollback adds a hook to call on rollback. +func (tx *Tx) OnRollback(f RollbackHook) { + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() +} + +// Client returns a Client that binds to current transaction. +func (tx *Tx) Client() *Client { + tx.clientOnce.Do(func() { + tx.client = &Client{config: tx.config} + tx.client.init() + }) + return tx.client +} + +func (tx *Tx) init() { + tx.DavAccount = NewDavAccountClient(tx.config) + tx.DirectLink = NewDirectLinkClient(tx.config) + tx.Entity = NewEntityClient(tx.config) + tx.File = NewFileClient(tx.config) + tx.Group = NewGroupClient(tx.config) + tx.Metadata = NewMetadataClient(tx.config) + tx.Node = NewNodeClient(tx.config) + tx.Passkey = NewPasskeyClient(tx.config) + tx.Setting = NewSettingClient(tx.config) + tx.Share = NewShareClient(tx.config) + tx.StoragePolicy = NewStoragePolicyClient(tx.config) + tx.Task = NewTaskClient(tx.config) + tx.User = NewUserClient(tx.config) +} + +// txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. +// The idea is to support transactions without adding any extra code to the builders. +// When a builder calls to driver.Tx(), it gets the same dialect.Tx instance. +// Commit and Rollback are nop for the internal builders and the user must call one +// of them in order to commit or rollback the transaction. +// +// If a closed transaction is embedded in one of the generated entities, and the entity +// applies a query, for example: DavAccount.QueryXXX(), the query will be executed +// through the driver which created this transaction. +// +// Note that txDriver is not goroutine safe. +type txDriver struct { + // the driver we started the transaction from. + drv dialect.Driver + // tx is the underlying transaction. + tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook +} + +// newTx creates a new transactional driver. +func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { + tx, err := drv.Tx(ctx) + if err != nil { + return nil, err + } + return &txDriver{tx: tx, drv: drv}, nil +} + +// Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls +// from the internal builders. Should be called only by the internal builders. +func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } + +// Dialect returns the dialect of the driver we started the transaction from. +func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } + +// Close is a nop close. +func (*txDriver) Close() error { return nil } + +// Commit is a nop commit for the internal builders. +// User must call `Tx.Commit` in order to commit the transaction. +func (*txDriver) Commit() error { return nil } + +// Rollback is a nop rollback for the internal builders. +// User must call `Tx.Rollback` in order to rollback the transaction. +func (*txDriver) Rollback() error { return nil } + +// Exec calls tx.Exec. 
+func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { + return tx.tx.Exec(ctx, query, args, v) +} + +// Query calls tx.Query. +func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { + return tx.tx.Query(ctx, query, args, v) +} + +var _ dialect.Driver = (*txDriver)(nil) + +// ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it. +// See, database/sql#Tx.ExecContext for more information. +func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { + ex, ok := tx.tx.(interface { + ExecContext(context.Context, string, ...any) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) +} + +// QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it. +// See, database/sql#Tx.QueryContext for more information. +func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { + q, ok := tx.tx.(interface { + QueryContext(context.Context, string, ...any) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) +} diff --git a/ent/user.go b/ent/user.go new file mode 100644 index 00000000..a530b102 --- /dev/null +++ b/ent/user.go @@ -0,0 +1,413 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// User is the model entity for the User schema. +type User struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // Email holds the value of the "email" field. + Email string `json:"email,omitempty"` + // Nick holds the value of the "nick" field. + Nick string `json:"nick,omitempty"` + // Password holds the value of the "password" field. + Password string `json:"-"` + // Status holds the value of the "status" field. + Status user.Status `json:"status,omitempty"` + // Storage holds the value of the "storage" field. + Storage int64 `json:"storage,omitempty"` + // TwoFactorSecret holds the value of the "two_factor_secret" field. + TwoFactorSecret string `json:"-"` + // Avatar holds the value of the "avatar" field. + Avatar string `json:"avatar,omitempty"` + // Settings holds the value of the "settings" field. + Settings *types.UserSetting `json:"settings,omitempty"` + // GroupUsers holds the value of the "group_users" field. + GroupUsers int `json:"group_users,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the UserQuery when eager-loading is set. + Edges UserEdges `json:"edges"` + storage_policy_users *int + selectValues sql.SelectValues +} + +// UserEdges holds the relations/edges for other nodes in the graph. +type UserEdges struct { + // Group holds the value of the group edge. 
+ Group *Group `json:"group,omitempty"` + // Files holds the value of the files edge. + Files []*File `json:"files,omitempty"` + // DavAccounts holds the value of the dav_accounts edge. + DavAccounts []*DavAccount `json:"dav_accounts,omitempty"` + // Shares holds the value of the shares edge. + Shares []*Share `json:"shares,omitempty"` + // Passkey holds the value of the passkey edge. + Passkey []*Passkey `json:"passkey,omitempty"` + // Tasks holds the value of the tasks edge. + Tasks []*Task `json:"tasks,omitempty"` + // Entities holds the value of the entities edge. + Entities []*Entity `json:"entities,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [7]bool +} + +// GroupOrErr returns the Group value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UserEdges) GroupOrErr() (*Group, error) { + if e.loadedTypes[0] { + if e.Group == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: group.Label} + } + return e.Group, nil + } + return nil, &NotLoadedError{edge: "group"} +} + +// FilesOrErr returns the Files value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) FilesOrErr() ([]*File, error) { + if e.loadedTypes[1] { + return e.Files, nil + } + return nil, &NotLoadedError{edge: "files"} +} + +// DavAccountsOrErr returns the DavAccounts value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) DavAccountsOrErr() ([]*DavAccount, error) { + if e.loadedTypes[2] { + return e.DavAccounts, nil + } + return nil, &NotLoadedError{edge: "dav_accounts"} +} + +// SharesOrErr returns the Shares value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) SharesOrErr() ([]*Share, error) { + if e.loadedTypes[3] { + return e.Shares, nil + } + return nil, &NotLoadedError{edge: "shares"} +} + +// PasskeyOrErr returns the Passkey value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PasskeyOrErr() ([]*Passkey, error) { + if e.loadedTypes[4] { + return e.Passkey, nil + } + return nil, &NotLoadedError{edge: "passkey"} +} + +// TasksOrErr returns the Tasks value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) TasksOrErr() ([]*Task, error) { + if e.loadedTypes[5] { + return e.Tasks, nil + } + return nil, &NotLoadedError{edge: "tasks"} +} + +// EntitiesOrErr returns the Entities value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) EntitiesOrErr() ([]*Entity, error) { + if e.loadedTypes[6] { + return e.Entities, nil + } + return nil, &NotLoadedError{edge: "entities"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*User) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case user.FieldSettings: + values[i] = new([]byte) + case user.FieldID, user.FieldStorage, user.FieldGroupUsers: + values[i] = new(sql.NullInt64) + case user.FieldEmail, user.FieldNick, user.FieldPassword, user.FieldStatus, user.FieldTwoFactorSecret, user.FieldAvatar: + values[i] = new(sql.NullString) + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt: + values[i] = new(sql.NullTime) + case user.ForeignKeys[0]: // storage_policy_users + values[i] = new(sql.NullInt64) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the User fields. +func (u *User) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case user.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + u.ID = int(value.Int64) + case user.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + u.CreatedAt = value.Time + } + case user.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + u.UpdatedAt = value.Time + } + case user.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + u.DeletedAt = new(time.Time) + *u.DeletedAt = value.Time + } + case user.FieldEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field email", values[i]) + } else if value.Valid { + u.Email = value.String + } + case user.FieldNick: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field nick", values[i]) + } else if value.Valid { + u.Nick = value.String + } + case user.FieldPassword: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field password", values[i]) + } else if value.Valid { + u.Password = value.String + } + case user.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + u.Status = user.Status(value.String) + } + case user.FieldStorage: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field storage", values[i]) + } else if value.Valid { + u.Storage = value.Int64 + } + case user.FieldTwoFactorSecret: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field two_factor_secret", values[i]) + } else if value.Valid { + u.TwoFactorSecret = value.String + } + case user.FieldAvatar: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field avatar", values[i]) + } else if value.Valid { + u.Avatar = value.String + } + case user.FieldSettings: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field settings", values[i]) + } else if value != nil && len(*value) > 0 { 
+ if err := json.Unmarshal(*value, &u.Settings); err != nil { + return fmt.Errorf("unmarshal field settings: %w", err) + } + } + case user.FieldGroupUsers: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field group_users", values[i]) + } else if value.Valid { + u.GroupUsers = int(value.Int64) + } + case user.ForeignKeys[0]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field storage_policy_users", value) + } else if value.Valid { + u.storage_policy_users = new(int) + *u.storage_policy_users = int(value.Int64) + } + default: + u.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the User. +// This includes values selected through modifiers, order, etc. +func (u *User) Value(name string) (ent.Value, error) { + return u.selectValues.Get(name) +} + +// QueryGroup queries the "group" edge of the User entity. +func (u *User) QueryGroup() *GroupQuery { + return NewUserClient(u.config).QueryGroup(u) +} + +// QueryFiles queries the "files" edge of the User entity. +func (u *User) QueryFiles() *FileQuery { + return NewUserClient(u.config).QueryFiles(u) +} + +// QueryDavAccounts queries the "dav_accounts" edge of the User entity. +func (u *User) QueryDavAccounts() *DavAccountQuery { + return NewUserClient(u.config).QueryDavAccounts(u) +} + +// QueryShares queries the "shares" edge of the User entity. +func (u *User) QueryShares() *ShareQuery { + return NewUserClient(u.config).QueryShares(u) +} + +// QueryPasskey queries the "passkey" edge of the User entity. +func (u *User) QueryPasskey() *PasskeyQuery { + return NewUserClient(u.config).QueryPasskey(u) +} + +// QueryTasks queries the "tasks" edge of the User entity. +func (u *User) QueryTasks() *TaskQuery { + return NewUserClient(u.config).QueryTasks(u) +} + +// QueryEntities queries the "entities" edge of the User entity. +func (u *User) QueryEntities() *EntityQuery { + return NewUserClient(u.config).QueryEntities(u) +} + +// Update returns a builder for updating this User. +// Note that you need to call User.Unwrap() before calling this method if this User +// was returned from a transaction, and the transaction was committed or rolled back. +func (u *User) Update() *UserUpdateOne { + return NewUserClient(u.config).UpdateOne(u) +} + +// Unwrap unwraps the User entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (u *User) Unwrap() *User { + _tx, ok := u.config.driver.(*txDriver) + if !ok { + panic("ent: User is not a transactional entity") + } + u.config.driver = _tx.drv + return u +} + +// String implements the fmt.Stringer. 
+func (u *User) String() string { + var builder strings.Builder + builder.WriteString("User(") + builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) + builder.WriteString("created_at=") + builder.WriteString(u.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(u.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := u.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("email=") + builder.WriteString(u.Email) + builder.WriteString(", ") + builder.WriteString("nick=") + builder.WriteString(u.Nick) + builder.WriteString(", ") + builder.WriteString("password=") + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", u.Status)) + builder.WriteString(", ") + builder.WriteString("storage=") + builder.WriteString(fmt.Sprintf("%v", u.Storage)) + builder.WriteString(", ") + builder.WriteString("two_factor_secret=") + builder.WriteString(", ") + builder.WriteString("avatar=") + builder.WriteString(u.Avatar) + builder.WriteString(", ") + builder.WriteString("settings=") + builder.WriteString(fmt.Sprintf("%v", u.Settings)) + builder.WriteString(", ") + builder.WriteString("group_users=") + builder.WriteString(fmt.Sprintf("%v", u.GroupUsers)) + builder.WriteByte(')') + return builder.String() +} + +// SetGroup manually set the edge as loaded state. +func (e *User) SetGroup(v *Group) { + e.Edges.Group = v + e.Edges.loadedTypes[0] = true +} + +// SetFiles manually set the edge as loaded state. +func (e *User) SetFiles(v []*File) { + e.Edges.Files = v + e.Edges.loadedTypes[1] = true +} + +// SetDavAccounts manually set the edge as loaded state. +func (e *User) SetDavAccounts(v []*DavAccount) { + e.Edges.DavAccounts = v + e.Edges.loadedTypes[2] = true +} + +// SetShares manually set the edge as loaded state. +func (e *User) SetShares(v []*Share) { + e.Edges.Shares = v + e.Edges.loadedTypes[3] = true +} + +// SetPasskey manually set the edge as loaded state. +func (e *User) SetPasskey(v []*Passkey) { + e.Edges.Passkey = v + e.Edges.loadedTypes[4] = true +} + +// SetTasks manually set the edge as loaded state. +func (e *User) SetTasks(v []*Task) { + e.Edges.Tasks = v + e.Edges.loadedTypes[5] = true +} + +// SetEntities manually set the edge as loaded state. +func (e *User) SetEntities(v []*Entity) { + e.Edges.Entities = v + e.Edges.loadedTypes[6] = true +} + +// Users is a parsable slice of User. +type Users []*User diff --git a/ent/user/user.go b/ent/user/user.go new file mode 100644 index 00000000..781dc024 --- /dev/null +++ b/ent/user/user.go @@ -0,0 +1,402 @@ +// Code generated by ent, DO NOT EDIT. + +package user + +import ( + "fmt" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +const ( + // Label holds the string label denoting the user type in the database. + Label = "user" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldEmail holds the string denoting the email field in the database. 
+ FieldEmail = "email" + // FieldNick holds the string denoting the nick field in the database. + FieldNick = "nick" + // FieldPassword holds the string denoting the password field in the database. + FieldPassword = "password" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldStorage holds the string denoting the storage field in the database. + FieldStorage = "storage" + // FieldTwoFactorSecret holds the string denoting the two_factor_secret field in the database. + FieldTwoFactorSecret = "two_factor_secret" + // FieldAvatar holds the string denoting the avatar field in the database. + FieldAvatar = "avatar" + // FieldSettings holds the string denoting the settings field in the database. + FieldSettings = "settings" + // FieldGroupUsers holds the string denoting the group_users field in the database. + FieldGroupUsers = "group_users" + // EdgeGroup holds the string denoting the group edge name in mutations. + EdgeGroup = "group" + // EdgeFiles holds the string denoting the files edge name in mutations. + EdgeFiles = "files" + // EdgeDavAccounts holds the string denoting the dav_accounts edge name in mutations. + EdgeDavAccounts = "dav_accounts" + // EdgeShares holds the string denoting the shares edge name in mutations. + EdgeShares = "shares" + // EdgePasskey holds the string denoting the passkey edge name in mutations. + EdgePasskey = "passkey" + // EdgeTasks holds the string denoting the tasks edge name in mutations. + EdgeTasks = "tasks" + // EdgeEntities holds the string denoting the entities edge name in mutations. + EdgeEntities = "entities" + // Table holds the table name of the user in the database. + Table = "users" + // GroupTable is the table that holds the group relation/edge. + GroupTable = "users" + // GroupInverseTable is the table name for the Group entity. + // It exists in this package in order to avoid circular dependency with the "group" package. + GroupInverseTable = "groups" + // GroupColumn is the table column denoting the group relation/edge. + GroupColumn = "group_users" + // FilesTable is the table that holds the files relation/edge. + FilesTable = "files" + // FilesInverseTable is the table name for the File entity. + // It exists in this package in order to avoid circular dependency with the "file" package. + FilesInverseTable = "files" + // FilesColumn is the table column denoting the files relation/edge. + FilesColumn = "owner_id" + // DavAccountsTable is the table that holds the dav_accounts relation/edge. + DavAccountsTable = "dav_accounts" + // DavAccountsInverseTable is the table name for the DavAccount entity. + // It exists in this package in order to avoid circular dependency with the "davaccount" package. + DavAccountsInverseTable = "dav_accounts" + // DavAccountsColumn is the table column denoting the dav_accounts relation/edge. + DavAccountsColumn = "owner_id" + // SharesTable is the table that holds the shares relation/edge. + SharesTable = "shares" + // SharesInverseTable is the table name for the Share entity. + // It exists in this package in order to avoid circular dependency with the "share" package. + SharesInverseTable = "shares" + // SharesColumn is the table column denoting the shares relation/edge. + SharesColumn = "user_shares" + // PasskeyTable is the table that holds the passkey relation/edge. + PasskeyTable = "passkeys" + // PasskeyInverseTable is the table name for the Passkey entity. 
+ // It exists in this package in order to avoid circular dependency with the "passkey" package. + PasskeyInverseTable = "passkeys" + // PasskeyColumn is the table column denoting the passkey relation/edge. + PasskeyColumn = "user_id" + // TasksTable is the table that holds the tasks relation/edge. + TasksTable = "tasks" + // TasksInverseTable is the table name for the Task entity. + // It exists in this package in order to avoid circular dependency with the "task" package. + TasksInverseTable = "tasks" + // TasksColumn is the table column denoting the tasks relation/edge. + TasksColumn = "user_tasks" + // EntitiesTable is the table that holds the entities relation/edge. + EntitiesTable = "entities" + // EntitiesInverseTable is the table name for the Entity entity. + // It exists in this package in order to avoid circular dependency with the "entity" package. + EntitiesInverseTable = "entities" + // EntitiesColumn is the table column denoting the entities relation/edge. + EntitiesColumn = "created_by" +) + +// Columns holds all SQL columns for user fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldEmail, + FieldNick, + FieldPassword, + FieldStatus, + FieldStorage, + FieldTwoFactorSecret, + FieldAvatar, + FieldSettings, + FieldGroupUsers, +} + +// ForeignKeys holds the SQL foreign-keys that are owned by the "users" +// table and are not defined as standalone fields in the schema. +var ForeignKeys = []string{ + "storage_policy_users", +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + for i := range ForeignKeys { + if column == ForeignKeys[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // EmailValidator is a validator for the "email" field. It is called by the builders before save. + EmailValidator func(string) error + // NickValidator is a validator for the "nick" field. It is called by the builders before save. + NickValidator func(string) error + // DefaultStorage holds the default value on creation for the "storage" field. + DefaultStorage int64 + // DefaultSettings holds the default value on creation for the "settings" field. + DefaultSettings *types.UserSetting +) + +// Status defines the type for the "status" enum field. +type Status string + +// StatusActive is the default value of the Status enum. +const DefaultStatus = StatusActive + +// Status values. +const ( + StatusActive Status = "active" + StatusInactive Status = "inactive" + StatusManualBanned Status = "manual_banned" + StatusSysBanned Status = "sys_banned" +) + +func (s Status) String() string { + return string(s) +} + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. 
+func StatusValidator(s Status) error { + switch s { + case StatusActive, StatusInactive, StatusManualBanned, StatusSysBanned: + return nil + default: + return fmt.Errorf("user: invalid enum value for status field: %q", s) + } +} + +// OrderOption defines the ordering options for the User queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByEmail orders the results by the email field. +func ByEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmail, opts...).ToFunc() +} + +// ByNick orders the results by the nick field. +func ByNick(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNick, opts...).ToFunc() +} + +// ByPassword orders the results by the password field. +func ByPassword(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassword, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByStorage orders the results by the storage field. +func ByStorage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStorage, opts...).ToFunc() +} + +// ByTwoFactorSecret orders the results by the two_factor_secret field. +func ByTwoFactorSecret(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTwoFactorSecret, opts...).ToFunc() +} + +// ByAvatar orders the results by the avatar field. +func ByAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAvatar, opts...).ToFunc() +} + +// ByGroupUsers orders the results by the group_users field. +func ByGroupUsers(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGroupUsers, opts...).ToFunc() +} + +// ByGroupField orders the results by group field. +func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) + } +} + +// ByFilesCount orders the results by files count. +func ByFilesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newFilesStep(), opts...) + } +} + +// ByFiles orders the results by files terms. +func ByFiles(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newFilesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByDavAccountsCount orders the results by dav_accounts count. +func ByDavAccountsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDavAccountsStep(), opts...) + } +} + +// ByDavAccounts orders the results by dav_accounts terms. 
+func ByDavAccounts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDavAccountsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// BySharesCount orders the results by shares count. +func BySharesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSharesStep(), opts...) + } +} + +// ByShares orders the results by shares terms. +func ByShares(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSharesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPasskeyCount orders the results by passkey count. +func ByPasskeyCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPasskeyStep(), opts...) + } +} + +// ByPasskey orders the results by passkey terms. +func ByPasskey(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPasskeyStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByTasksCount orders the results by tasks count. +func ByTasksCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newTasksStep(), opts...) + } +} + +// ByTasks orders the results by tasks terms. +func ByTasks(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTasksStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByEntitiesCount orders the results by entities count. +func ByEntitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newEntitiesStep(), opts...) + } +} + +// ByEntities orders the results by entities terms. +func ByEntities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newEntitiesStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newGroupStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(GroupInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) +} +func newFilesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(FilesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, FilesTable, FilesColumn), + ) +} +func newDavAccountsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DavAccountsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DavAccountsTable, DavAccountsColumn), + ) +} +func newSharesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SharesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SharesTable, SharesColumn), + ) +} +func newPasskeyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PasskeyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PasskeyTable, PasskeyColumn), + ) +} +func newTasksStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TasksInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, TasksTable, TasksColumn), + ) +} +func newEntitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(EntitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EntitiesTable, EntitiesColumn), + ) +} diff --git a/ent/user/where.go b/ent/user/where.go new file mode 100644 index 00000000..5400662a --- /dev/null +++ b/ent/user/where.go @@ -0,0 +1,857 @@ +// Code generated by ent, DO NOT EDIT. + +package user + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.User { + return predicate.User(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.User { + return predicate.User(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.User { + return predicate.User(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.User { + return predicate.User(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.User { + return predicate.User(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.User { + return predicate.User(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.User { + return predicate.User(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.User { + return predicate.User(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.User { + return predicate.User(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. 
+func UpdatedAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldDeletedAt, v)) +} + +// Email applies equality check predicate on the "email" field. It's identical to EmailEQ. +func Email(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldEmail, v)) +} + +// Nick applies equality check predicate on the "nick" field. It's identical to NickEQ. +func Nick(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldNick, v)) +} + +// Password applies equality check predicate on the "password" field. It's identical to PasswordEQ. +func Password(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldPassword, v)) +} + +// Storage applies equality check predicate on the "storage" field. It's identical to StorageEQ. +func Storage(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldStorage, v)) +} + +// TwoFactorSecret applies equality check predicate on the "two_factor_secret" field. It's identical to TwoFactorSecretEQ. +func TwoFactorSecret(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldTwoFactorSecret, v)) +} + +// Avatar applies equality check predicate on the "avatar" field. It's identical to AvatarEQ. +func Avatar(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldAvatar, v)) +} + +// GroupUsers applies equality check predicate on the "group_users" field. It's identical to GroupUsersEQ. +func GroupUsers(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldGroupUsers, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
+func UpdatedAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldDeletedAt)) +} + +// EmailEQ applies the EQ predicate on the "email" field. +func EmailEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldEmail, v)) +} + +// EmailNEQ applies the NEQ predicate on the "email" field. +func EmailNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldEmail, v)) +} + +// EmailIn applies the In predicate on the "email" field. 
+func EmailIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldEmail, vs...)) +} + +// EmailNotIn applies the NotIn predicate on the "email" field. +func EmailNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldEmail, vs...)) +} + +// EmailGT applies the GT predicate on the "email" field. +func EmailGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldEmail, v)) +} + +// EmailGTE applies the GTE predicate on the "email" field. +func EmailGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldEmail, v)) +} + +// EmailLT applies the LT predicate on the "email" field. +func EmailLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldEmail, v)) +} + +// EmailLTE applies the LTE predicate on the "email" field. +func EmailLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldEmail, v)) +} + +// EmailContains applies the Contains predicate on the "email" field. +func EmailContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldEmail, v)) +} + +// EmailHasPrefix applies the HasPrefix predicate on the "email" field. +func EmailHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldEmail, v)) +} + +// EmailHasSuffix applies the HasSuffix predicate on the "email" field. +func EmailHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldEmail, v)) +} + +// EmailEqualFold applies the EqualFold predicate on the "email" field. +func EmailEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldEmail, v)) +} + +// EmailContainsFold applies the ContainsFold predicate on the "email" field. +func EmailContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldEmail, v)) +} + +// NickEQ applies the EQ predicate on the "nick" field. +func NickEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldNick, v)) +} + +// NickNEQ applies the NEQ predicate on the "nick" field. +func NickNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldNick, v)) +} + +// NickIn applies the In predicate on the "nick" field. +func NickIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldNick, vs...)) +} + +// NickNotIn applies the NotIn predicate on the "nick" field. +func NickNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldNick, vs...)) +} + +// NickGT applies the GT predicate on the "nick" field. +func NickGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldNick, v)) +} + +// NickGTE applies the GTE predicate on the "nick" field. +func NickGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldNick, v)) +} + +// NickLT applies the LT predicate on the "nick" field. +func NickLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldNick, v)) +} + +// NickLTE applies the LTE predicate on the "nick" field. +func NickLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldNick, v)) +} + +// NickContains applies the Contains predicate on the "nick" field. +func NickContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldNick, v)) +} + +// NickHasPrefix applies the HasPrefix predicate on the "nick" field. +func NickHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldNick, v)) +} + +// NickHasSuffix applies the HasSuffix predicate on the "nick" field. 
+func NickHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldNick, v)) +} + +// NickEqualFold applies the EqualFold predicate on the "nick" field. +func NickEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldNick, v)) +} + +// NickContainsFold applies the ContainsFold predicate on the "nick" field. +func NickContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldNick, v)) +} + +// PasswordEQ applies the EQ predicate on the "password" field. +func PasswordEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldPassword, v)) +} + +// PasswordNEQ applies the NEQ predicate on the "password" field. +func PasswordNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldPassword, v)) +} + +// PasswordIn applies the In predicate on the "password" field. +func PasswordIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldPassword, vs...)) +} + +// PasswordNotIn applies the NotIn predicate on the "password" field. +func PasswordNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldPassword, vs...)) +} + +// PasswordGT applies the GT predicate on the "password" field. +func PasswordGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldPassword, v)) +} + +// PasswordGTE applies the GTE predicate on the "password" field. +func PasswordGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldPassword, v)) +} + +// PasswordLT applies the LT predicate on the "password" field. +func PasswordLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldPassword, v)) +} + +// PasswordLTE applies the LTE predicate on the "password" field. +func PasswordLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldPassword, v)) +} + +// PasswordContains applies the Contains predicate on the "password" field. +func PasswordContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldPassword, v)) +} + +// PasswordHasPrefix applies the HasPrefix predicate on the "password" field. +func PasswordHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldPassword, v)) +} + +// PasswordHasSuffix applies the HasSuffix predicate on the "password" field. +func PasswordHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldPassword, v)) +} + +// PasswordIsNil applies the IsNil predicate on the "password" field. +func PasswordIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldPassword)) +} + +// PasswordNotNil applies the NotNil predicate on the "password" field. +func PasswordNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldPassword)) +} + +// PasswordEqualFold applies the EqualFold predicate on the "password" field. +func PasswordEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldPassword, v)) +} + +// PasswordContainsFold applies the ContainsFold predicate on the "password" field. +func PasswordContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldPassword, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v Status) predicate.User { + return predicate.User(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. 
+func StatusNEQ(v Status) predicate.User { + return predicate.User(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...Status) predicate.User { + return predicate.User(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...Status) predicate.User { + return predicate.User(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StorageEQ applies the EQ predicate on the "storage" field. +func StorageEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldStorage, v)) +} + +// StorageNEQ applies the NEQ predicate on the "storage" field. +func StorageNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldStorage, v)) +} + +// StorageIn applies the In predicate on the "storage" field. +func StorageIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldStorage, vs...)) +} + +// StorageNotIn applies the NotIn predicate on the "storage" field. +func StorageNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldStorage, vs...)) +} + +// StorageGT applies the GT predicate on the "storage" field. +func StorageGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldStorage, v)) +} + +// StorageGTE applies the GTE predicate on the "storage" field. +func StorageGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldStorage, v)) +} + +// StorageLT applies the LT predicate on the "storage" field. +func StorageLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldStorage, v)) +} + +// StorageLTE applies the LTE predicate on the "storage" field. +func StorageLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldStorage, v)) +} + +// TwoFactorSecretEQ applies the EQ predicate on the "two_factor_secret" field. +func TwoFactorSecretEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretNEQ applies the NEQ predicate on the "two_factor_secret" field. +func TwoFactorSecretNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretIn applies the In predicate on the "two_factor_secret" field. +func TwoFactorSecretIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldTwoFactorSecret, vs...)) +} + +// TwoFactorSecretNotIn applies the NotIn predicate on the "two_factor_secret" field. +func TwoFactorSecretNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldTwoFactorSecret, vs...)) +} + +// TwoFactorSecretGT applies the GT predicate on the "two_factor_secret" field. +func TwoFactorSecretGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretGTE applies the GTE predicate on the "two_factor_secret" field. +func TwoFactorSecretGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretLT applies the LT predicate on the "two_factor_secret" field. +func TwoFactorSecretLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretLTE applies the LTE predicate on the "two_factor_secret" field. +func TwoFactorSecretLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretContains applies the Contains predicate on the "two_factor_secret" field. 
+func TwoFactorSecretContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretHasPrefix applies the HasPrefix predicate on the "two_factor_secret" field. +func TwoFactorSecretHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretHasSuffix applies the HasSuffix predicate on the "two_factor_secret" field. +func TwoFactorSecretHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretIsNil applies the IsNil predicate on the "two_factor_secret" field. +func TwoFactorSecretIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldTwoFactorSecret)) +} + +// TwoFactorSecretNotNil applies the NotNil predicate on the "two_factor_secret" field. +func TwoFactorSecretNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldTwoFactorSecret)) +} + +// TwoFactorSecretEqualFold applies the EqualFold predicate on the "two_factor_secret" field. +func TwoFactorSecretEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldTwoFactorSecret, v)) +} + +// TwoFactorSecretContainsFold applies the ContainsFold predicate on the "two_factor_secret" field. +func TwoFactorSecretContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldTwoFactorSecret, v)) +} + +// AvatarEQ applies the EQ predicate on the "avatar" field. +func AvatarEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldAvatar, v)) +} + +// AvatarNEQ applies the NEQ predicate on the "avatar" field. +func AvatarNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldAvatar, v)) +} + +// AvatarIn applies the In predicate on the "avatar" field. +func AvatarIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldAvatar, vs...)) +} + +// AvatarNotIn applies the NotIn predicate on the "avatar" field. +func AvatarNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldAvatar, vs...)) +} + +// AvatarGT applies the GT predicate on the "avatar" field. +func AvatarGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldAvatar, v)) +} + +// AvatarGTE applies the GTE predicate on the "avatar" field. +func AvatarGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldAvatar, v)) +} + +// AvatarLT applies the LT predicate on the "avatar" field. +func AvatarLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldAvatar, v)) +} + +// AvatarLTE applies the LTE predicate on the "avatar" field. +func AvatarLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldAvatar, v)) +} + +// AvatarContains applies the Contains predicate on the "avatar" field. +func AvatarContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldAvatar, v)) +} + +// AvatarHasPrefix applies the HasPrefix predicate on the "avatar" field. +func AvatarHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldAvatar, v)) +} + +// AvatarHasSuffix applies the HasSuffix predicate on the "avatar" field. +func AvatarHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldAvatar, v)) +} + +// AvatarIsNil applies the IsNil predicate on the "avatar" field. 
+func AvatarIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldAvatar)) +} + +// AvatarNotNil applies the NotNil predicate on the "avatar" field. +func AvatarNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldAvatar)) +} + +// AvatarEqualFold applies the EqualFold predicate on the "avatar" field. +func AvatarEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldAvatar, v)) +} + +// AvatarContainsFold applies the ContainsFold predicate on the "avatar" field. +func AvatarContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldAvatar, v)) +} + +// SettingsIsNil applies the IsNil predicate on the "settings" field. +func SettingsIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldSettings)) +} + +// SettingsNotNil applies the NotNil predicate on the "settings" field. +func SettingsNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldSettings)) +} + +// GroupUsersEQ applies the EQ predicate on the "group_users" field. +func GroupUsersEQ(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldGroupUsers, v)) +} + +// GroupUsersNEQ applies the NEQ predicate on the "group_users" field. +func GroupUsersNEQ(v int) predicate.User { + return predicate.User(sql.FieldNEQ(FieldGroupUsers, v)) +} + +// GroupUsersIn applies the In predicate on the "group_users" field. +func GroupUsersIn(vs ...int) predicate.User { + return predicate.User(sql.FieldIn(FieldGroupUsers, vs...)) +} + +// GroupUsersNotIn applies the NotIn predicate on the "group_users" field. +func GroupUsersNotIn(vs ...int) predicate.User { + return predicate.User(sql.FieldNotIn(FieldGroupUsers, vs...)) +} + +// HasGroup applies the HasEdge predicate on the "group" edge. +func HasGroup() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasGroupWith applies the HasEdge predicate on the "group" edge with a given conditions (other predicates). +func HasGroupWith(preds ...predicate.Group) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newGroupStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasFiles applies the HasEdge predicate on the "files" edge. +func HasFiles() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, FilesTable, FilesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasFilesWith applies the HasEdge predicate on the "files" edge with a given conditions (other predicates). +func HasFilesWith(preds ...predicate.File) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newFilesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasDavAccounts applies the HasEdge predicate on the "dav_accounts" edge. 
+func HasDavAccounts() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DavAccountsTable, DavAccountsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasDavAccountsWith applies the HasEdge predicate on the "dav_accounts" edge with a given conditions (other predicates). +func HasDavAccountsWith(preds ...predicate.DavAccount) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newDavAccountsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasShares applies the HasEdge predicate on the "shares" edge. +func HasShares() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SharesTable, SharesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasSharesWith applies the HasEdge predicate on the "shares" edge with a given conditions (other predicates). +func HasSharesWith(preds ...predicate.Share) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newSharesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPasskey applies the HasEdge predicate on the "passkey" edge. +func HasPasskey() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PasskeyTable, PasskeyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPasskeyWith applies the HasEdge predicate on the "passkey" edge with a given conditions (other predicates). +func HasPasskeyWith(preds ...predicate.Passkey) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPasskeyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasTasks applies the HasEdge predicate on the "tasks" edge. +func HasTasks() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, TasksTable, TasksColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTasksWith applies the HasEdge predicate on the "tasks" edge with a given conditions (other predicates). +func HasTasksWith(preds ...predicate.Task) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newTasksStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasEntities applies the HasEdge predicate on the "entities" edge. +func HasEntities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, EntitiesTable, EntitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasEntitiesWith applies the HasEdge predicate on the "entities" edge with a given conditions (other predicates). +func HasEntitiesWith(preds ...predicate.Entity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newEntitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. 
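+// For example, an illustrative query combining several of the predicates above
+// (assuming an initialized *ent.Client named "client"):
+//
+//	active, err := client.User.Query().
+//		Where(user.And(
+//			user.StatusEQ(user.StatusActive),
+//			user.EmailContainsFold("@example.com"),
+//		)).
+//		All(ctx)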
+func And(predicates ...predicate.User) predicate.User { + return predicate.User(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.User) predicate.User { + return predicate.User(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.User) predicate.User { + return predicate.User(sql.NotPredicates(p)) +} diff --git a/ent/user_create.go b/ent/user_create.go new file mode 100644 index 00000000..4e5f9278 --- /dev/null +++ b/ent/user_create.go @@ -0,0 +1,1462 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// UserCreate is the builder for creating a User entity. +type UserCreate struct { + config + mutation *UserMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (uc *UserCreate) SetCreatedAt(t time.Time) *UserCreate { + uc.mutation.SetCreatedAt(t) + return uc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (uc *UserCreate) SetNillableCreatedAt(t *time.Time) *UserCreate { + if t != nil { + uc.SetCreatedAt(*t) + } + return uc +} + +// SetUpdatedAt sets the "updated_at" field. +func (uc *UserCreate) SetUpdatedAt(t time.Time) *UserCreate { + uc.mutation.SetUpdatedAt(t) + return uc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (uc *UserCreate) SetNillableUpdatedAt(t *time.Time) *UserCreate { + if t != nil { + uc.SetUpdatedAt(*t) + } + return uc +} + +// SetDeletedAt sets the "deleted_at" field. +func (uc *UserCreate) SetDeletedAt(t time.Time) *UserCreate { + uc.mutation.SetDeletedAt(t) + return uc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uc *UserCreate) SetNillableDeletedAt(t *time.Time) *UserCreate { + if t != nil { + uc.SetDeletedAt(*t) + } + return uc +} + +// SetEmail sets the "email" field. +func (uc *UserCreate) SetEmail(s string) *UserCreate { + uc.mutation.SetEmail(s) + return uc +} + +// SetNick sets the "nick" field. +func (uc *UserCreate) SetNick(s string) *UserCreate { + uc.mutation.SetNick(s) + return uc +} + +// SetPassword sets the "password" field. +func (uc *UserCreate) SetPassword(s string) *UserCreate { + uc.mutation.SetPassword(s) + return uc +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (uc *UserCreate) SetNillablePassword(s *string) *UserCreate { + if s != nil { + uc.SetPassword(*s) + } + return uc +} + +// SetStatus sets the "status" field. +func (uc *UserCreate) SetStatus(u user.Status) *UserCreate { + uc.mutation.SetStatus(u) + return uc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. 
+func (uc *UserCreate) SetNillableStatus(u *user.Status) *UserCreate { + if u != nil { + uc.SetStatus(*u) + } + return uc +} + +// SetStorage sets the "storage" field. +func (uc *UserCreate) SetStorage(i int64) *UserCreate { + uc.mutation.SetStorage(i) + return uc +} + +// SetNillableStorage sets the "storage" field if the given value is not nil. +func (uc *UserCreate) SetNillableStorage(i *int64) *UserCreate { + if i != nil { + uc.SetStorage(*i) + } + return uc +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (uc *UserCreate) SetTwoFactorSecret(s string) *UserCreate { + uc.mutation.SetTwoFactorSecret(s) + return uc +} + +// SetNillableTwoFactorSecret sets the "two_factor_secret" field if the given value is not nil. +func (uc *UserCreate) SetNillableTwoFactorSecret(s *string) *UserCreate { + if s != nil { + uc.SetTwoFactorSecret(*s) + } + return uc +} + +// SetAvatar sets the "avatar" field. +func (uc *UserCreate) SetAvatar(s string) *UserCreate { + uc.mutation.SetAvatar(s) + return uc +} + +// SetNillableAvatar sets the "avatar" field if the given value is not nil. +func (uc *UserCreate) SetNillableAvatar(s *string) *UserCreate { + if s != nil { + uc.SetAvatar(*s) + } + return uc +} + +// SetSettings sets the "settings" field. +func (uc *UserCreate) SetSettings(ts *types.UserSetting) *UserCreate { + uc.mutation.SetSettings(ts) + return uc +} + +// SetGroupUsers sets the "group_users" field. +func (uc *UserCreate) SetGroupUsers(i int) *UserCreate { + uc.mutation.SetGroupUsers(i) + return uc +} + +// SetGroupID sets the "group" edge to the Group entity by ID. +func (uc *UserCreate) SetGroupID(id int) *UserCreate { + uc.mutation.SetGroupID(id) + return uc +} + +// SetGroup sets the "group" edge to the Group entity. +func (uc *UserCreate) SetGroup(g *Group) *UserCreate { + return uc.SetGroupID(g.ID) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (uc *UserCreate) AddFileIDs(ids ...int) *UserCreate { + uc.mutation.AddFileIDs(ids...) + return uc +} + +// AddFiles adds the "files" edges to the File entity. +func (uc *UserCreate) AddFiles(f ...*File) *UserCreate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return uc.AddFileIDs(ids...) +} + +// AddDavAccountIDs adds the "dav_accounts" edge to the DavAccount entity by IDs. +func (uc *UserCreate) AddDavAccountIDs(ids ...int) *UserCreate { + uc.mutation.AddDavAccountIDs(ids...) + return uc +} + +// AddDavAccounts adds the "dav_accounts" edges to the DavAccount entity. +func (uc *UserCreate) AddDavAccounts(d ...*DavAccount) *UserCreate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return uc.AddDavAccountIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (uc *UserCreate) AddShareIDs(ids ...int) *UserCreate { + uc.mutation.AddShareIDs(ids...) + return uc +} + +// AddShares adds the "shares" edges to the Share entity. +func (uc *UserCreate) AddShares(s ...*Share) *UserCreate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return uc.AddShareIDs(ids...) +} + +// AddPasskeyIDs adds the "passkey" edge to the Passkey entity by IDs. +func (uc *UserCreate) AddPasskeyIDs(ids ...int) *UserCreate { + uc.mutation.AddPasskeyIDs(ids...) + return uc +} + +// AddPasskey adds the "passkey" edges to the Passkey entity. +func (uc *UserCreate) AddPasskey(p ...*Passkey) *UserCreate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uc.AddPasskeyIDs(ids...) 
+} + +// AddTaskIDs adds the "tasks" edge to the Task entity by IDs. +func (uc *UserCreate) AddTaskIDs(ids ...int) *UserCreate { + uc.mutation.AddTaskIDs(ids...) + return uc +} + +// AddTasks adds the "tasks" edges to the Task entity. +func (uc *UserCreate) AddTasks(t ...*Task) *UserCreate { + ids := make([]int, len(t)) + for i := range t { + ids[i] = t[i].ID + } + return uc.AddTaskIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (uc *UserCreate) AddEntityIDs(ids ...int) *UserCreate { + uc.mutation.AddEntityIDs(ids...) + return uc +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (uc *UserCreate) AddEntities(e ...*Entity) *UserCreate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return uc.AddEntityIDs(ids...) +} + +// Mutation returns the UserMutation object of the builder. +func (uc *UserCreate) Mutation() *UserMutation { + return uc.mutation +} + +// Save creates the User in the database. +func (uc *UserCreate) Save(ctx context.Context) (*User, error) { + if err := uc.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, uc.sqlSave, uc.mutation, uc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (uc *UserCreate) SaveX(ctx context.Context) *User { + v, err := uc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (uc *UserCreate) Exec(ctx context.Context) error { + _, err := uc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (uc *UserCreate) ExecX(ctx context.Context) { + if err := uc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (uc *UserCreate) defaults() error { + if _, ok := uc.mutation.CreatedAt(); !ok { + if user.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized user.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := user.DefaultCreatedAt() + uc.mutation.SetCreatedAt(v) + } + if _, ok := uc.mutation.UpdatedAt(); !ok { + if user.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized user.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := user.DefaultUpdatedAt() + uc.mutation.SetUpdatedAt(v) + } + if _, ok := uc.mutation.Status(); !ok { + v := user.DefaultStatus + uc.mutation.SetStatus(v) + } + if _, ok := uc.mutation.Storage(); !ok { + v := user.DefaultStorage + uc.mutation.SetStorage(v) + } + if _, ok := uc.mutation.Settings(); !ok { + v := user.DefaultSettings + uc.mutation.SetSettings(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
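+// For example, an illustrative create call that relies on these defaults and
+// validators (assuming an initialized *ent.Client named "client" and an
+// existing group with ID 1):
+//
+//	u, err := client.User.Create().
+//		SetEmail("user@example.com").
+//		SetNick("user").
+//		SetGroupID(1).
+//		Save(ctx)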
+func (uc *UserCreate) check() error { + if _, ok := uc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "User.created_at"`)} + } + if _, ok := uc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "User.updated_at"`)} + } + if _, ok := uc.mutation.Email(); !ok { + return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)} + } + if v, ok := uc.mutation.Email(); ok { + if err := user.EmailValidator(v); err != nil { + return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} + } + } + if _, ok := uc.mutation.Nick(); !ok { + return &ValidationError{Name: "nick", err: errors.New(`ent: missing required field "User.nick"`)} + } + if v, ok := uc.mutation.Nick(); ok { + if err := user.NickValidator(v); err != nil { + return &ValidationError{Name: "nick", err: fmt.Errorf(`ent: validator failed for field "User.nick": %w`, err)} + } + } + if _, ok := uc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "User.status"`)} + } + if v, ok := uc.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } + if _, ok := uc.mutation.Storage(); !ok { + return &ValidationError{Name: "storage", err: errors.New(`ent: missing required field "User.storage"`)} + } + if _, ok := uc.mutation.GroupUsers(); !ok { + return &ValidationError{Name: "group_users", err: errors.New(`ent: missing required field "User.group_users"`)} + } + if _, ok := uc.mutation.GroupID(); !ok { + return &ValidationError{Name: "group", err: errors.New(`ent: missing required edge "User.group"`)} + } + return nil +} + +func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) { + if err := uc.check(); err != nil { + return nil, err + } + _node, _spec := uc.createSpec() + if err := sqlgraph.CreateNode(ctx, uc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + uc.mutation.id = &_node.ID + uc.mutation.done = true + return _node, nil +} + +func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { + var ( + _node = &User{config: uc.config} + _spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) + ) + + if id, ok := uc.mutation.ID(); ok { + _node.ID = id + id64 := int64(id) + _spec.ID.Value = id64 + } + + _spec.OnConflict = uc.conflict + if value, ok := uc.mutation.CreatedAt(); ok { + _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := uc.mutation.UpdatedAt(); ok { + _spec.SetField(user.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := uc.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := uc.mutation.Email(); ok { + _spec.SetField(user.FieldEmail, field.TypeString, value) + _node.Email = value + } + if value, ok := uc.mutation.Nick(); ok { + _spec.SetField(user.FieldNick, field.TypeString, value) + _node.Nick = value + } + if value, ok := uc.mutation.Password(); ok { + _spec.SetField(user.FieldPassword, field.TypeString, value) + _node.Password = 
value + } + if value, ok := uc.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := uc.mutation.Storage(); ok { + _spec.SetField(user.FieldStorage, field.TypeInt64, value) + _node.Storage = value + } + if value, ok := uc.mutation.TwoFactorSecret(); ok { + _spec.SetField(user.FieldTwoFactorSecret, field.TypeString, value) + _node.TwoFactorSecret = value + } + if value, ok := uc.mutation.Avatar(); ok { + _spec.SetField(user.FieldAvatar, field.TypeString, value) + _node.Avatar = value + } + if value, ok := uc.mutation.Settings(); ok { + _spec.SetField(user.FieldSettings, field.TypeJSON, value) + _node.Settings = value + } + if nodes := uc.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: user.GroupTable, + Columns: []string{user.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.GroupUsers = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.DavAccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.PasskeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.TasksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := uc.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: 
[]string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.User.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (uc *UserCreate) OnConflict(opts ...sql.ConflictOption) *UserUpsertOne { + uc.conflict = opts + return &UserUpsertOne{ + create: uc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (uc *UserCreate) OnConflictColumns(columns ...string) *UserUpsertOne { + uc.conflict = append(uc.conflict, sql.ConflictColumns(columns...)) + return &UserUpsertOne{ + create: uc, + } +} + +type ( + // UserUpsertOne is the builder for "upsert"-ing + // one User node. + UserUpsertOne struct { + create *UserCreate + } + + // UserUpsert is the "OnConflict" setter. + UserUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserUpsert) SetUpdatedAt(v time.Time) *UserUpsert { + u.Set(user.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateUpdatedAt() *UserUpsert { + u.SetExcluded(user.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsert) SetDeletedAt(v time.Time) *UserUpsert { + u.Set(user.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateDeletedAt() *UserUpsert { + u.SetExcluded(user.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsert) ClearDeletedAt() *UserUpsert { + u.SetNull(user.FieldDeletedAt) + return u +} + +// SetEmail sets the "email" field. +func (u *UserUpsert) SetEmail(v string) *UserUpsert { + u.Set(user.FieldEmail, v) + return u +} + +// UpdateEmail sets the "email" field to the value that was provided on create. +func (u *UserUpsert) UpdateEmail() *UserUpsert { + u.SetExcluded(user.FieldEmail) + return u +} + +// SetNick sets the "nick" field. +func (u *UserUpsert) SetNick(v string) *UserUpsert { + u.Set(user.FieldNick, v) + return u +} + +// UpdateNick sets the "nick" field to the value that was provided on create. +func (u *UserUpsert) UpdateNick() *UserUpsert { + u.SetExcluded(user.FieldNick) + return u +} + +// SetPassword sets the "password" field. +func (u *UserUpsert) SetPassword(v string) *UserUpsert { + u.Set(user.FieldPassword, v) + return u +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *UserUpsert) UpdatePassword() *UserUpsert { + u.SetExcluded(user.FieldPassword) + return u +} + +// ClearPassword clears the value of the "password" field. 
+func (u *UserUpsert) ClearPassword() *UserUpsert { + u.SetNull(user.FieldPassword) + return u +} + +// SetStatus sets the "status" field. +func (u *UserUpsert) SetStatus(v user.Status) *UserUpsert { + u.Set(user.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UserUpsert) UpdateStatus() *UserUpsert { + u.SetExcluded(user.FieldStatus) + return u +} + +// SetStorage sets the "storage" field. +func (u *UserUpsert) SetStorage(v int64) *UserUpsert { + u.Set(user.FieldStorage, v) + return u +} + +// UpdateStorage sets the "storage" field to the value that was provided on create. +func (u *UserUpsert) UpdateStorage() *UserUpsert { + u.SetExcluded(user.FieldStorage) + return u +} + +// AddStorage adds v to the "storage" field. +func (u *UserUpsert) AddStorage(v int64) *UserUpsert { + u.Add(user.FieldStorage, v) + return u +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (u *UserUpsert) SetTwoFactorSecret(v string) *UserUpsert { + u.Set(user.FieldTwoFactorSecret, v) + return u +} + +// UpdateTwoFactorSecret sets the "two_factor_secret" field to the value that was provided on create. +func (u *UserUpsert) UpdateTwoFactorSecret() *UserUpsert { + u.SetExcluded(user.FieldTwoFactorSecret) + return u +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (u *UserUpsert) ClearTwoFactorSecret() *UserUpsert { + u.SetNull(user.FieldTwoFactorSecret) + return u +} + +// SetAvatar sets the "avatar" field. +func (u *UserUpsert) SetAvatar(v string) *UserUpsert { + u.Set(user.FieldAvatar, v) + return u +} + +// UpdateAvatar sets the "avatar" field to the value that was provided on create. +func (u *UserUpsert) UpdateAvatar() *UserUpsert { + u.SetExcluded(user.FieldAvatar) + return u +} + +// ClearAvatar clears the value of the "avatar" field. +func (u *UserUpsert) ClearAvatar() *UserUpsert { + u.SetNull(user.FieldAvatar) + return u +} + +// SetSettings sets the "settings" field. +func (u *UserUpsert) SetSettings(v *types.UserSetting) *UserUpsert { + u.Set(user.FieldSettings, v) + return u +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *UserUpsert) UpdateSettings() *UserUpsert { + u.SetExcluded(user.FieldSettings) + return u +} + +// ClearSettings clears the value of the "settings" field. +func (u *UserUpsert) ClearSettings() *UserUpsert { + u.SetNull(user.FieldSettings) + return u +} + +// SetGroupUsers sets the "group_users" field. +func (u *UserUpsert) SetGroupUsers(v int) *UserUpsert { + u.Set(user.FieldGroupUsers, v) + return u +} + +// UpdateGroupUsers sets the "group_users" field to the value that was provided on create. +func (u *UserUpsert) UpdateGroupUsers() *UserUpsert { + u.SetExcluded(user.FieldGroupUsers) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserUpsertOne) UpdateNewValues() *UserUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(user.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. 
+// Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserUpsertOne) Ignore() *UserUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserUpsertOne) DoNothing() *UserUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserCreate.OnConflict +// documentation for more info. +func (u *UserUpsertOne) Update(set func(*UserUpsert)) *UserUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserUpsertOne) SetUpdatedAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateUpdatedAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsertOne) SetDeletedAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateDeletedAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsertOne) ClearDeletedAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearDeletedAt() + }) +} + +// SetEmail sets the "email" field. +func (u *UserUpsertOne) SetEmail(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetEmail(v) + }) +} + +// UpdateEmail sets the "email" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateEmail() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateEmail() + }) +} + +// SetNick sets the "nick" field. +func (u *UserUpsertOne) SetNick(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetNick(v) + }) +} + +// UpdateNick sets the "nick" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateNick() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateNick() + }) +} + +// SetPassword sets the "password" field. +func (u *UserUpsertOne) SetPassword(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *UserUpsertOne) UpdatePassword() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdatePassword() + }) +} + +// ClearPassword clears the value of the "password" field. +func (u *UserUpsertOne) ClearPassword() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearPassword() + }) +} + +// SetStatus sets the "status" field. +func (u *UserUpsertOne) SetStatus(v user.Status) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. 
+func (u *UserUpsertOne) UpdateStatus() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateStatus() + }) +} + +// SetStorage sets the "storage" field. +func (u *UserUpsertOne) SetStorage(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetStorage(v) + }) +} + +// AddStorage adds v to the "storage" field. +func (u *UserUpsertOne) AddStorage(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddStorage(v) + }) +} + +// UpdateStorage sets the "storage" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateStorage() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateStorage() + }) +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (u *UserUpsertOne) SetTwoFactorSecret(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTwoFactorSecret(v) + }) +} + +// UpdateTwoFactorSecret sets the "two_factor_secret" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTwoFactorSecret() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTwoFactorSecret() + }) +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (u *UserUpsertOne) ClearTwoFactorSecret() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearTwoFactorSecret() + }) +} + +// SetAvatar sets the "avatar" field. +func (u *UserUpsertOne) SetAvatar(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetAvatar(v) + }) +} + +// UpdateAvatar sets the "avatar" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateAvatar() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateAvatar() + }) +} + +// ClearAvatar clears the value of the "avatar" field. +func (u *UserUpsertOne) ClearAvatar() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearAvatar() + }) +} + +// SetSettings sets the "settings" field. +func (u *UserUpsertOne) SetSettings(v *types.UserSetting) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSettings() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *UserUpsertOne) ClearSettings() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearSettings() + }) +} + +// SetGroupUsers sets the "group_users" field. +func (u *UserUpsertOne) SetGroupUsers(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetGroupUsers(v) + }) +} + +// UpdateGroupUsers sets the "group_users" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateGroupUsers() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateGroupUsers() + }) +} + +// Exec executes the query. +func (u *UserUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. 
+func (u *UserUpsertOne) ID(ctx context.Context) (id int, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *UserUpsertOne) IDX(ctx context.Context) int { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +func (m *UserCreate) SetRawID(t int) *UserCreate { + m.mutation.SetRawID(t) + return m +} + +// UserCreateBulk is the builder for creating many User entities in bulk. +type UserCreateBulk struct { + config + err error + builders []*UserCreate + conflict []sql.ConflictOption +} + +// Save creates the User entities in the database. +func (ucb *UserCreateBulk) Save(ctx context.Context) ([]*User, error) { + if ucb.err != nil { + return nil, ucb.err + } + specs := make([]*sqlgraph.CreateSpec, len(ucb.builders)) + nodes := make([]*User, len(ucb.builders)) + mutators := make([]Mutator, len(ucb.builders)) + for i := range ucb.builders { + func(i int, root context.Context) { + builder := ucb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UserMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, ucb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = ucb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, ucb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, ucb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (ucb *UserCreateBulk) SaveX(ctx context.Context) []*User { + v, err := ucb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (ucb *UserCreateBulk) Exec(ctx context.Context) error { + _, err := ucb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (ucb *UserCreateBulk) ExecX(ctx context.Context) { + if err := ucb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.User.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (ucb *UserCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserUpsertBulk { + ucb.conflict = opts + return &UserUpsertBulk{ + create: ucb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (ucb *UserCreateBulk) OnConflictColumns(columns ...string) *UserUpsertBulk { + ucb.conflict = append(ucb.conflict, sql.ConflictColumns(columns...)) + return &UserUpsertBulk{ + create: ucb, + } +} + +// UserUpsertBulk is the builder for "upsert"-ing +// a bulk of User nodes. +type UserUpsertBulk struct { + create *UserCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserUpsertBulk) UpdateNewValues() *UserUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(user.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.User.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserUpsertBulk) Ignore() *UserUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserUpsertBulk) DoNothing() *UserUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserCreateBulk.OnConflict +// documentation for more info. +func (u *UserUpsertBulk) Update(set func(*UserUpsert)) *UserUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserUpsertBulk) SetUpdatedAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateUpdatedAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsertBulk) SetDeletedAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateDeletedAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsertBulk) ClearDeletedAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearDeletedAt() + }) +} + +// SetEmail sets the "email" field. +func (u *UserUpsertBulk) SetEmail(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetEmail(v) + }) +} + +// UpdateEmail sets the "email" field to the value that was provided on create. 
+func (u *UserUpsertBulk) UpdateEmail() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateEmail() + }) +} + +// SetNick sets the "nick" field. +func (u *UserUpsertBulk) SetNick(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetNick(v) + }) +} + +// UpdateNick sets the "nick" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateNick() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateNick() + }) +} + +// SetPassword sets the "password" field. +func (u *UserUpsertBulk) SetPassword(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetPassword(v) + }) +} + +// UpdatePassword sets the "password" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdatePassword() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdatePassword() + }) +} + +// ClearPassword clears the value of the "password" field. +func (u *UserUpsertBulk) ClearPassword() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearPassword() + }) +} + +// SetStatus sets the "status" field. +func (u *UserUpsertBulk) SetStatus(v user.Status) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateStatus() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateStatus() + }) +} + +// SetStorage sets the "storage" field. +func (u *UserUpsertBulk) SetStorage(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetStorage(v) + }) +} + +// AddStorage adds v to the "storage" field. +func (u *UserUpsertBulk) AddStorage(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddStorage(v) + }) +} + +// UpdateStorage sets the "storage" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateStorage() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateStorage() + }) +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (u *UserUpsertBulk) SetTwoFactorSecret(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTwoFactorSecret(v) + }) +} + +// UpdateTwoFactorSecret sets the "two_factor_secret" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTwoFactorSecret() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTwoFactorSecret() + }) +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (u *UserUpsertBulk) ClearTwoFactorSecret() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearTwoFactorSecret() + }) +} + +// SetAvatar sets the "avatar" field. +func (u *UserUpsertBulk) SetAvatar(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetAvatar(v) + }) +} + +// UpdateAvatar sets the "avatar" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateAvatar() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateAvatar() + }) +} + +// ClearAvatar clears the value of the "avatar" field. +func (u *UserUpsertBulk) ClearAvatar() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearAvatar() + }) +} + +// SetSettings sets the "settings" field. +func (u *UserUpsertBulk) SetSettings(v *types.UserSetting) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSettings(v) + }) +} + +// UpdateSettings sets the "settings" field to the value that was provided on create. 
+func (u *UserUpsertBulk) UpdateSettings() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSettings() + }) +} + +// ClearSettings clears the value of the "settings" field. +func (u *UserUpsertBulk) ClearSettings() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearSettings() + }) +} + +// SetGroupUsers sets the "group_users" field. +func (u *UserUpsertBulk) SetGroupUsers(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetGroupUsers(v) + }) +} + +// UpdateGroupUsers sets the "group_users" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateGroupUsers() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateGroupUsers() + }) +} + +// Exec executes the query. +func (u *UserUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/user_delete.go b/ent/user_delete.go new file mode 100644 index 00000000..94530ffd --- /dev/null +++ b/ent/user_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// UserDelete is the builder for deleting a User entity. +type UserDelete struct { + config + hooks []Hook + mutation *UserMutation +} + +// Where appends a list predicates to the UserDelete builder. +func (ud *UserDelete) Where(ps ...predicate.User) *UserDelete { + ud.mutation.Where(ps...) + return ud +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ud *UserDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ud.sqlExec, ud.mutation, ud.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ud *UserDelete) ExecX(ctx context.Context) int { + n, err := ud.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ud *UserDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) + if ps := ud.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ud.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ud.mutation.done = true + return affected, err +} + +// UserDeleteOne is the builder for deleting a single User entity. +type UserDeleteOne struct { + ud *UserDelete +} + +// Where appends a list predicates to the UserDelete builder. +func (udo *UserDeleteOne) Where(ps ...predicate.User) *UserDeleteOne { + udo.ud.mutation.Where(ps...) + return udo +} + +// Exec executes the deletion query. 
+func (udo *UserDeleteOne) Exec(ctx context.Context) error { + n, err := udo.ud.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{user.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (udo *UserDeleteOne) ExecX(ctx context.Context) { + if err := udo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/user_query.go b/ent/user_query.go new file mode 100644 index 00000000..6917bdb2 --- /dev/null +++ b/ent/user_query.go @@ -0,0 +1,1056 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" +) + +// UserQuery is the builder for querying User entities. +type UserQuery struct { + config + ctx *QueryContext + order []user.OrderOption + inters []Interceptor + predicates []predicate.User + withGroup *GroupQuery + withFiles *FileQuery + withDavAccounts *DavAccountQuery + withShares *ShareQuery + withPasskey *PasskeyQuery + withTasks *TaskQuery + withEntities *EntityQuery + withFKs bool + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UserQuery builder. +func (uq *UserQuery) Where(ps ...predicate.User) *UserQuery { + uq.predicates = append(uq.predicates, ps...) + return uq +} + +// Limit the number of records to be returned by this query. +func (uq *UserQuery) Limit(limit int) *UserQuery { + uq.ctx.Limit = &limit + return uq +} + +// Offset to start from. +func (uq *UserQuery) Offset(offset int) *UserQuery { + uq.ctx.Offset = &offset + return uq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (uq *UserQuery) Unique(unique bool) *UserQuery { + uq.ctx.Unique = &unique + return uq +} + +// Order specifies how the records should be ordered. +func (uq *UserQuery) Order(o ...user.OrderOption) *UserQuery { + uq.order = append(uq.order, o...) + return uq +} + +// QueryGroup chains the current query on the "group" edge. +func (uq *UserQuery) QueryGroup() *GroupQuery { + query := (&GroupClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, user.GroupTable, user.GroupColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryFiles chains the current query on the "files" edge. 
+func (uq *UserQuery) QueryFiles() *FileQuery { + query := (&FileClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(file.Table, file.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.FilesTable, user.FilesColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryDavAccounts chains the current query on the "dav_accounts" edge. +func (uq *UserQuery) QueryDavAccounts() *DavAccountQuery { + query := (&DavAccountClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(davaccount.Table, davaccount.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.DavAccountsTable, user.DavAccountsColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryShares chains the current query on the "shares" edge. +func (uq *UserQuery) QueryShares() *ShareQuery { + query := (&ShareClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(share.Table, share.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.SharesTable, user.SharesColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPasskey chains the current query on the "passkey" edge. +func (uq *UserQuery) QueryPasskey() *PasskeyQuery { + query := (&PasskeyClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(passkey.Table, passkey.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PasskeyTable, user.PasskeyColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryTasks chains the current query on the "tasks" edge. 
+func (uq *UserQuery) QueryTasks() *TaskQuery { + query := (&TaskClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(task.Table, task.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.TasksTable, user.TasksColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryEntities chains the current query on the "entities" edge. +func (uq *UserQuery) QueryEntities() *EntityQuery { + query := (&EntityClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(entity.Table, entity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.EntitiesTable, user.EntitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first User entity from the query. +// Returns a *NotFoundError when no User was found. +func (uq *UserQuery) First(ctx context.Context) (*User, error) { + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{user.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (uq *UserQuery) FirstX(ctx context.Context) *User { + node, err := uq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first User ID from the query. +// Returns a *NotFoundError when no User ID was found. +func (uq *UserQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{user.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (uq *UserQuery) FirstIDX(ctx context.Context) int { + id, err := uq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single User entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one User entity is found. +// Returns a *NotFoundError when no User entities are found. +func (uq *UserQuery) Only(ctx context.Context) (*User, error) { + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{user.Label} + default: + return nil, &NotSingularError{user.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (uq *UserQuery) OnlyX(ctx context.Context) *User { + node, err := uq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only User ID in the query. +// Returns a *NotSingularError when more than one User ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (uq *UserQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{user.Label} + default: + err = &NotSingularError{user.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (uq *UserQuery) OnlyIDX(ctx context.Context) int { + id, err := uq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Users. +func (uq *UserQuery) All(ctx context.Context) ([]*User, error) { + ctx = setContextOp(ctx, uq.ctx, "All") + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*User, *UserQuery]() + return withInterceptors[[]*User](ctx, uq, qr, uq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (uq *UserQuery) AllX(ctx context.Context) []*User { + nodes, err := uq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of User IDs. +func (uq *UserQuery) IDs(ctx context.Context) (ids []int, err error) { + if uq.ctx.Unique == nil && uq.path != nil { + uq.Unique(true) + } + ctx = setContextOp(ctx, uq.ctx, "IDs") + if err = uq.Select(user.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (uq *UserQuery) IDsX(ctx context.Context) []int { + ids, err := uq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (uq *UserQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, uq.ctx, "Count") + if err := uq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, uq, querierCount[*UserQuery](), uq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (uq *UserQuery) CountX(ctx context.Context) int { + count, err := uq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (uq *UserQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, uq.ctx, "Exist") + switch _, err := uq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (uq *UserQuery) ExistX(ctx context.Context) bool { + exist, err := uq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UserQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (uq *UserQuery) Clone() *UserQuery { + if uq == nil { + return nil + } + return &UserQuery{ + config: uq.config, + ctx: uq.ctx.Clone(), + order: append([]user.OrderOption{}, uq.order...), + inters: append([]Interceptor{}, uq.inters...), + predicates: append([]predicate.User{}, uq.predicates...), + withGroup: uq.withGroup.Clone(), + withFiles: uq.withFiles.Clone(), + withDavAccounts: uq.withDavAccounts.Clone(), + withShares: uq.withShares.Clone(), + withPasskey: uq.withPasskey.Clone(), + withTasks: uq.withTasks.Clone(), + withEntities: uq.withEntities.Clone(), + // clone intermediate query. 
+ sql: uq.sql.Clone(), + path: uq.path, + } +} + +// WithGroup tells the query-builder to eager-load the nodes that are connected to +// the "group" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithGroup(opts ...func(*GroupQuery)) *UserQuery { + query := (&GroupClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withGroup = query + return uq +} + +// WithFiles tells the query-builder to eager-load the nodes that are connected to +// the "files" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithFiles(opts ...func(*FileQuery)) *UserQuery { + query := (&FileClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withFiles = query + return uq +} + +// WithDavAccounts tells the query-builder to eager-load the nodes that are connected to +// the "dav_accounts" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithDavAccounts(opts ...func(*DavAccountQuery)) *UserQuery { + query := (&DavAccountClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withDavAccounts = query + return uq +} + +// WithShares tells the query-builder to eager-load the nodes that are connected to +// the "shares" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithShares(opts ...func(*ShareQuery)) *UserQuery { + query := (&ShareClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withShares = query + return uq +} + +// WithPasskey tells the query-builder to eager-load the nodes that are connected to +// the "passkey" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithPasskey(opts ...func(*PasskeyQuery)) *UserQuery { + query := (&PasskeyClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withPasskey = query + return uq +} + +// WithTasks tells the query-builder to eager-load the nodes that are connected to +// the "tasks" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithTasks(opts ...func(*TaskQuery)) *UserQuery { + query := (&TaskClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withTasks = query + return uq +} + +// WithEntities tells the query-builder to eager-load the nodes that are connected to +// the "entities" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithEntities(opts ...func(*EntityQuery)) *UserQuery { + query := (&EntityClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withEntities = query + return uq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.User.Query(). +// GroupBy(user.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { + uq.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &UserGroupBy{build: uq} + grbuild.flds = &uq.ctx.Fields + grbuild.label = user.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.User.Query(). +// Select(user.FieldCreatedAt). +// Scan(ctx, &v) +func (uq *UserQuery) Select(fields ...string) *UserSelect { + uq.ctx.Fields = append(uq.ctx.Fields, fields...) + sbuild := &UserSelect{UserQuery: uq} + sbuild.label = user.Label + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UserSelect configured with the given aggregations. +func (uq *UserQuery) Aggregate(fns ...AggregateFunc) *UserSelect { + return uq.Select().Aggregate(fns...) +} + +func (uq *UserQuery) prepareQuery(ctx context.Context) error { + for _, inter := range uq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, uq); err != nil { + return err + } + } + } + for _, f := range uq.ctx.Fields { + if !user.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if uq.path != nil { + prev, err := uq.path(ctx) + if err != nil { + return err + } + uq.sql = prev + } + return nil +} + +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { + var ( + nodes = []*User{} + withFKs = uq.withFKs + _spec = uq.querySpec() + loadedTypes = [7]bool{ + uq.withGroup != nil, + uq.withFiles != nil, + uq.withDavAccounts != nil, + uq.withShares != nil, + uq.withPasskey != nil, + uq.withTasks != nil, + uq.withEntities != nil, + } + ) + if withFKs { + _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) 
+ } + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*User).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &User{config: uq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := uq.withGroup; query != nil { + if err := uq.loadGroup(ctx, query, nodes, nil, + func(n *User, e *Group) { n.Edges.Group = e }); err != nil { + return nil, err + } + } + if query := uq.withFiles; query != nil { + if err := uq.loadFiles(ctx, query, nodes, + func(n *User) { n.Edges.Files = []*File{} }, + func(n *User, e *File) { n.Edges.Files = append(n.Edges.Files, e) }); err != nil { + return nil, err + } + } + if query := uq.withDavAccounts; query != nil { + if err := uq.loadDavAccounts(ctx, query, nodes, + func(n *User) { n.Edges.DavAccounts = []*DavAccount{} }, + func(n *User, e *DavAccount) { n.Edges.DavAccounts = append(n.Edges.DavAccounts, e) }); err != nil { + return nil, err + } + } + if query := uq.withShares; query != nil { + if err := uq.loadShares(ctx, query, nodes, + func(n *User) { n.Edges.Shares = []*Share{} }, + func(n *User, e *Share) { n.Edges.Shares = append(n.Edges.Shares, e) }); err != nil { + return nil, err + } + } + if query := uq.withPasskey; query != nil { + if err := uq.loadPasskey(ctx, query, nodes, + func(n *User) { n.Edges.Passkey = []*Passkey{} }, + func(n *User, e *Passkey) { n.Edges.Passkey = append(n.Edges.Passkey, e) }); err != nil { + return nil, err + } + } + if query := uq.withTasks; query != nil { + if err := uq.loadTasks(ctx, query, nodes, + func(n *User) { n.Edges.Tasks = []*Task{} }, + func(n *User, e *Task) { n.Edges.Tasks = append(n.Edges.Tasks, e) }); err != nil { + return nil, err + } + } + if query := uq.withEntities; query != nil { + if err := uq.loadEntities(ctx, query, nodes, + func(n *User) { n.Edges.Entities = []*Entity{} }, + func(n *User, e *Entity) { n.Edges.Entities = append(n.Edges.Entities, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (uq *UserQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + fk := nodes[i].GroupUsers + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(group.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_users" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (uq *UserQuery) loadFiles(ctx context.Context, query *FileQuery, nodes []*User, init func(*User), assign func(*User, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(file.FieldOwnerID) + } + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.FilesColumn), 
fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.OwnerID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "owner_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadDavAccounts(ctx context.Context, query *DavAccountQuery, nodes []*User, init func(*User), assign func(*User, *DavAccount)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(davaccount.FieldOwnerID) + } + query.Where(predicate.DavAccount(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.DavAccountsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.OwnerID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "owner_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadShares(ctx context.Context, query *ShareQuery, nodes []*User, init func(*User), assign func(*User, *Share)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Share(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.SharesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_shares + if fk == nil { + return fmt.Errorf(`foreign-key "user_shares" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_shares" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadPasskey(ctx context.Context, query *PasskeyQuery, nodes []*User, init func(*User), assign func(*User, *Passkey)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(passkey.FieldUserID) + } + query.Where(predicate.Passkey(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PasskeyColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadTasks(ctx context.Context, query *TaskQuery, nodes []*User, init func(*User), assign func(*User, *Task)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(task.FieldUserTasks) + } + query.Where(predicate.Task(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.TasksColumn), fks...)) + })) + neighbors, err := 
query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserTasks + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_tasks" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadEntities(ctx context.Context, query *EntityQuery, nodes []*User, init func(*User), assign func(*User, *Entity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(entity.FieldCreatedBy) + } + query.Where(predicate.Entity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.EntitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.CreatedBy + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "created_by" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { + _spec := uq.querySpec() + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, uq.driver, _spec) +} + +func (uq *UserQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) + _spec.From = uq.sql + if unique := uq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if uq.path != nil { + _spec.Unique = true + } + if fields := uq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) + for i := range fields { + if fields[i] != user.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if uq.withGroup != nil { + _spec.Node.AddColumnOnce(user.FieldGroupUsers) + } + } + if ps := uq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := uq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := uq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := uq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (uq *UserQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(uq.driver.Dialect()) + t1 := builder.Table(user.Table) + columns := uq.ctx.Fields + if len(columns) == 0 { + columns = user.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if uq.sql != nil { + selector = uq.sql + selector.Select(selector.Columns(columns...)...) + } + if uq.ctx.Unique != nil && *uq.ctx.Unique { + selector.Distinct() + } + for _, p := range uq.predicates { + p(selector) + } + for _, p := range uq.order { + p(selector) + } + if offset := uq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := uq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// UserGroupBy is the group-by builder for User entities. 
+type UserGroupBy struct { + selector + build *UserQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (ugb *UserGroupBy) Aggregate(fns ...AggregateFunc) *UserGroupBy { + ugb.fns = append(ugb.fns, fns...) + return ugb +} + +// Scan applies the selector query and scans the result into the given value. +func (ugb *UserGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") + if err := ugb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserQuery, *UserGroupBy](ctx, ugb.build, ugb, ugb.build.inters, v) +} + +func (ugb *UserGroupBy) sqlScan(ctx context.Context, root *UserQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(ugb.fns)) + for _, fn := range ugb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*ugb.flds)+len(ugb.fns)) + for _, f := range *ugb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*ugb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ugb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UserSelect is the builder for selecting fields of User entities. +type UserSelect struct { + *UserQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (us *UserSelect) Aggregate(fns ...AggregateFunc) *UserSelect { + us.fns = append(us.fns, fns...) + return us +} + +// Scan applies the selector query and scans the result into the given value. +func (us *UserSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, us.ctx, "Select") + if err := us.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserQuery, *UserSelect](ctx, us.UserQuery, us, us.inters, v) +} + +func (us *UserSelect) sqlScan(ctx context.Context, root *UserQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(us.fns)) + for _, fn := range us.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*us.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := us.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/user_update.go b/ent/user_update.go new file mode 100644 index 00000000..4d67894a --- /dev/null +++ b/ent/user_update.go @@ -0,0 +1,1776 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" +) + +// UserUpdate is the builder for updating User entities. +type UserUpdate struct { + config + hooks []Hook + mutation *UserMutation +} + +// Where appends a list predicates to the UserUpdate builder. +func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { + uu.mutation.Where(ps...) + return uu +} + +// SetUpdatedAt sets the "updated_at" field. +func (uu *UserUpdate) SetUpdatedAt(t time.Time) *UserUpdate { + uu.mutation.SetUpdatedAt(t) + return uu +} + +// SetDeletedAt sets the "deleted_at" field. +func (uu *UserUpdate) SetDeletedAt(t time.Time) *UserUpdate { + uu.mutation.SetDeletedAt(t) + return uu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDeletedAt(t *time.Time) *UserUpdate { + if t != nil { + uu.SetDeletedAt(*t) + } + return uu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (uu *UserUpdate) ClearDeletedAt() *UserUpdate { + uu.mutation.ClearDeletedAt() + return uu +} + +// SetEmail sets the "email" field. +func (uu *UserUpdate) SetEmail(s string) *UserUpdate { + uu.mutation.SetEmail(s) + return uu +} + +// SetNillableEmail sets the "email" field if the given value is not nil. +func (uu *UserUpdate) SetNillableEmail(s *string) *UserUpdate { + if s != nil { + uu.SetEmail(*s) + } + return uu +} + +// SetNick sets the "nick" field. +func (uu *UserUpdate) SetNick(s string) *UserUpdate { + uu.mutation.SetNick(s) + return uu +} + +// SetNillableNick sets the "nick" field if the given value is not nil. +func (uu *UserUpdate) SetNillableNick(s *string) *UserUpdate { + if s != nil { + uu.SetNick(*s) + } + return uu +} + +// SetPassword sets the "password" field. +func (uu *UserUpdate) SetPassword(s string) *UserUpdate { + uu.mutation.SetPassword(s) + return uu +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (uu *UserUpdate) SetNillablePassword(s *string) *UserUpdate { + if s != nil { + uu.SetPassword(*s) + } + return uu +} + +// ClearPassword clears the value of the "password" field. +func (uu *UserUpdate) ClearPassword() *UserUpdate { + uu.mutation.ClearPassword() + return uu +} + +// SetStatus sets the "status" field. +func (uu *UserUpdate) SetStatus(u user.Status) *UserUpdate { + uu.mutation.SetStatus(u) + return uu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (uu *UserUpdate) SetNillableStatus(u *user.Status) *UserUpdate { + if u != nil { + uu.SetStatus(*u) + } + return uu +} + +// SetStorage sets the "storage" field. +func (uu *UserUpdate) SetStorage(i int64) *UserUpdate { + uu.mutation.ResetStorage() + uu.mutation.SetStorage(i) + return uu +} + +// SetNillableStorage sets the "storage" field if the given value is not nil. 
+func (uu *UserUpdate) SetNillableStorage(i *int64) *UserUpdate { + if i != nil { + uu.SetStorage(*i) + } + return uu +} + +// AddStorage adds i to the "storage" field. +func (uu *UserUpdate) AddStorage(i int64) *UserUpdate { + uu.mutation.AddStorage(i) + return uu +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (uu *UserUpdate) SetTwoFactorSecret(s string) *UserUpdate { + uu.mutation.SetTwoFactorSecret(s) + return uu +} + +// SetNillableTwoFactorSecret sets the "two_factor_secret" field if the given value is not nil. +func (uu *UserUpdate) SetNillableTwoFactorSecret(s *string) *UserUpdate { + if s != nil { + uu.SetTwoFactorSecret(*s) + } + return uu +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (uu *UserUpdate) ClearTwoFactorSecret() *UserUpdate { + uu.mutation.ClearTwoFactorSecret() + return uu +} + +// SetAvatar sets the "avatar" field. +func (uu *UserUpdate) SetAvatar(s string) *UserUpdate { + uu.mutation.SetAvatar(s) + return uu +} + +// SetNillableAvatar sets the "avatar" field if the given value is not nil. +func (uu *UserUpdate) SetNillableAvatar(s *string) *UserUpdate { + if s != nil { + uu.SetAvatar(*s) + } + return uu +} + +// ClearAvatar clears the value of the "avatar" field. +func (uu *UserUpdate) ClearAvatar() *UserUpdate { + uu.mutation.ClearAvatar() + return uu +} + +// SetSettings sets the "settings" field. +func (uu *UserUpdate) SetSettings(ts *types.UserSetting) *UserUpdate { + uu.mutation.SetSettings(ts) + return uu +} + +// ClearSettings clears the value of the "settings" field. +func (uu *UserUpdate) ClearSettings() *UserUpdate { + uu.mutation.ClearSettings() + return uu +} + +// SetGroupUsers sets the "group_users" field. +func (uu *UserUpdate) SetGroupUsers(i int) *UserUpdate { + uu.mutation.SetGroupUsers(i) + return uu +} + +// SetNillableGroupUsers sets the "group_users" field if the given value is not nil. +func (uu *UserUpdate) SetNillableGroupUsers(i *int) *UserUpdate { + if i != nil { + uu.SetGroupUsers(*i) + } + return uu +} + +// SetGroupID sets the "group" edge to the Group entity by ID. +func (uu *UserUpdate) SetGroupID(id int) *UserUpdate { + uu.mutation.SetGroupID(id) + return uu +} + +// SetGroup sets the "group" edge to the Group entity. +func (uu *UserUpdate) SetGroup(g *Group) *UserUpdate { + return uu.SetGroupID(g.ID) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (uu *UserUpdate) AddFileIDs(ids ...int) *UserUpdate { + uu.mutation.AddFileIDs(ids...) + return uu +} + +// AddFiles adds the "files" edges to the File entity. +func (uu *UserUpdate) AddFiles(f ...*File) *UserUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return uu.AddFileIDs(ids...) +} + +// AddDavAccountIDs adds the "dav_accounts" edge to the DavAccount entity by IDs. +func (uu *UserUpdate) AddDavAccountIDs(ids ...int) *UserUpdate { + uu.mutation.AddDavAccountIDs(ids...) + return uu +} + +// AddDavAccounts adds the "dav_accounts" edges to the DavAccount entity. +func (uu *UserUpdate) AddDavAccounts(d ...*DavAccount) *UserUpdate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return uu.AddDavAccountIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (uu *UserUpdate) AddShareIDs(ids ...int) *UserUpdate { + uu.mutation.AddShareIDs(ids...) + return uu +} + +// AddShares adds the "shares" edges to the Share entity. 
+func (uu *UserUpdate) AddShares(s ...*Share) *UserUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return uu.AddShareIDs(ids...) +} + +// AddPasskeyIDs adds the "passkey" edge to the Passkey entity by IDs. +func (uu *UserUpdate) AddPasskeyIDs(ids ...int) *UserUpdate { + uu.mutation.AddPasskeyIDs(ids...) + return uu +} + +// AddPasskey adds the "passkey" edges to the Passkey entity. +func (uu *UserUpdate) AddPasskey(p ...*Passkey) *UserUpdate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uu.AddPasskeyIDs(ids...) +} + +// AddTaskIDs adds the "tasks" edge to the Task entity by IDs. +func (uu *UserUpdate) AddTaskIDs(ids ...int) *UserUpdate { + uu.mutation.AddTaskIDs(ids...) + return uu +} + +// AddTasks adds the "tasks" edges to the Task entity. +func (uu *UserUpdate) AddTasks(t ...*Task) *UserUpdate { + ids := make([]int, len(t)) + for i := range t { + ids[i] = t[i].ID + } + return uu.AddTaskIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (uu *UserUpdate) AddEntityIDs(ids ...int) *UserUpdate { + uu.mutation.AddEntityIDs(ids...) + return uu +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (uu *UserUpdate) AddEntities(e ...*Entity) *UserUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return uu.AddEntityIDs(ids...) +} + +// Mutation returns the UserMutation object of the builder. +func (uu *UserUpdate) Mutation() *UserMutation { + return uu.mutation +} + +// ClearGroup clears the "group" edge to the Group entity. +func (uu *UserUpdate) ClearGroup() *UserUpdate { + uu.mutation.ClearGroup() + return uu +} + +// ClearFiles clears all "files" edges to the File entity. +func (uu *UserUpdate) ClearFiles() *UserUpdate { + uu.mutation.ClearFiles() + return uu +} + +// RemoveFileIDs removes the "files" edge to File entities by IDs. +func (uu *UserUpdate) RemoveFileIDs(ids ...int) *UserUpdate { + uu.mutation.RemoveFileIDs(ids...) + return uu +} + +// RemoveFiles removes "files" edges to File entities. +func (uu *UserUpdate) RemoveFiles(f ...*File) *UserUpdate { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return uu.RemoveFileIDs(ids...) +} + +// ClearDavAccounts clears all "dav_accounts" edges to the DavAccount entity. +func (uu *UserUpdate) ClearDavAccounts() *UserUpdate { + uu.mutation.ClearDavAccounts() + return uu +} + +// RemoveDavAccountIDs removes the "dav_accounts" edge to DavAccount entities by IDs. +func (uu *UserUpdate) RemoveDavAccountIDs(ids ...int) *UserUpdate { + uu.mutation.RemoveDavAccountIDs(ids...) + return uu +} + +// RemoveDavAccounts removes "dav_accounts" edges to DavAccount entities. +func (uu *UserUpdate) RemoveDavAccounts(d ...*DavAccount) *UserUpdate { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return uu.RemoveDavAccountIDs(ids...) +} + +// ClearShares clears all "shares" edges to the Share entity. +func (uu *UserUpdate) ClearShares() *UserUpdate { + uu.mutation.ClearShares() + return uu +} + +// RemoveShareIDs removes the "shares" edge to Share entities by IDs. +func (uu *UserUpdate) RemoveShareIDs(ids ...int) *UserUpdate { + uu.mutation.RemoveShareIDs(ids...) + return uu +} + +// RemoveShares removes "shares" edges to Share entities. +func (uu *UserUpdate) RemoveShares(s ...*Share) *UserUpdate { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return uu.RemoveShareIDs(ids...) 
+} + +// ClearPasskey clears all "passkey" edges to the Passkey entity. +func (uu *UserUpdate) ClearPasskey() *UserUpdate { + uu.mutation.ClearPasskey() + return uu +} + +// RemovePasskeyIDs removes the "passkey" edge to Passkey entities by IDs. +func (uu *UserUpdate) RemovePasskeyIDs(ids ...int) *UserUpdate { + uu.mutation.RemovePasskeyIDs(ids...) + return uu +} + +// RemovePasskey removes "passkey" edges to Passkey entities. +func (uu *UserUpdate) RemovePasskey(p ...*Passkey) *UserUpdate { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uu.RemovePasskeyIDs(ids...) +} + +// ClearTasks clears all "tasks" edges to the Task entity. +func (uu *UserUpdate) ClearTasks() *UserUpdate { + uu.mutation.ClearTasks() + return uu +} + +// RemoveTaskIDs removes the "tasks" edge to Task entities by IDs. +func (uu *UserUpdate) RemoveTaskIDs(ids ...int) *UserUpdate { + uu.mutation.RemoveTaskIDs(ids...) + return uu +} + +// RemoveTasks removes "tasks" edges to Task entities. +func (uu *UserUpdate) RemoveTasks(t ...*Task) *UserUpdate { + ids := make([]int, len(t)) + for i := range t { + ids[i] = t[i].ID + } + return uu.RemoveTaskIDs(ids...) +} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (uu *UserUpdate) ClearEntities() *UserUpdate { + uu.mutation.ClearEntities() + return uu +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (uu *UserUpdate) RemoveEntityIDs(ids ...int) *UserUpdate { + uu.mutation.RemoveEntityIDs(ids...) + return uu +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (uu *UserUpdate) RemoveEntities(e ...*Entity) *UserUpdate { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return uu.RemoveEntityIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (uu *UserUpdate) Save(ctx context.Context) (int, error) { + if err := uu.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, uu.sqlSave, uu.mutation, uu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (uu *UserUpdate) SaveX(ctx context.Context) int { + affected, err := uu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (uu *UserUpdate) Exec(ctx context.Context) error { + _, err := uu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (uu *UserUpdate) ExecX(ctx context.Context) { + if err := uu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (uu *UserUpdate) defaults() error { + if _, ok := uu.mutation.UpdatedAt(); !ok { + if user.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized user.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := user.UpdateDefaultUpdatedAt() + uu.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. 
+func (uu *UserUpdate) check() error { + if v, ok := uu.mutation.Email(); ok { + if err := user.EmailValidator(v); err != nil { + return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} + } + } + if v, ok := uu.mutation.Nick(); ok { + if err := user.NickValidator(v); err != nil { + return &ValidationError{Name: "nick", err: fmt.Errorf(`ent: validator failed for field "User.nick": %w`, err)} + } + } + if v, ok := uu.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } + if _, ok := uu.mutation.GroupID(); uu.mutation.GroupCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "User.group"`) + } + return nil +} + +func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := uu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) + if ps := uu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := uu.mutation.UpdatedAt(); ok { + _spec.SetField(user.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := uu.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + } + if uu.mutation.DeletedAtCleared() { + _spec.ClearField(user.FieldDeletedAt, field.TypeTime) + } + if value, ok := uu.mutation.Email(); ok { + _spec.SetField(user.FieldEmail, field.TypeString, value) + } + if value, ok := uu.mutation.Nick(); ok { + _spec.SetField(user.FieldNick, field.TypeString, value) + } + if value, ok := uu.mutation.Password(); ok { + _spec.SetField(user.FieldPassword, field.TypeString, value) + } + if uu.mutation.PasswordCleared() { + _spec.ClearField(user.FieldPassword, field.TypeString) + } + if value, ok := uu.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + } + if value, ok := uu.mutation.Storage(); ok { + _spec.SetField(user.FieldStorage, field.TypeInt64, value) + } + if value, ok := uu.mutation.AddedStorage(); ok { + _spec.AddField(user.FieldStorage, field.TypeInt64, value) + } + if value, ok := uu.mutation.TwoFactorSecret(); ok { + _spec.SetField(user.FieldTwoFactorSecret, field.TypeString, value) + } + if uu.mutation.TwoFactorSecretCleared() { + _spec.ClearField(user.FieldTwoFactorSecret, field.TypeString) + } + if value, ok := uu.mutation.Avatar(); ok { + _spec.SetField(user.FieldAvatar, field.TypeString, value) + } + if uu.mutation.AvatarCleared() { + _spec.ClearField(user.FieldAvatar, field.TypeString) + } + if value, ok := uu.mutation.Settings(); ok { + _spec.SetField(user.FieldSettings, field.TypeJSON, value) + } + if uu.mutation.SettingsCleared() { + _spec.ClearField(user.FieldSettings, field.TypeJSON) + } + if uu.mutation.GroupCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: user.GroupTable, + Columns: []string{user.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: user.GroupTable, + Columns: []string{user.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: 
sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedFilesIDs(); len(nodes) > 0 && !uu.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.DavAccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedDavAccountsIDs(); len(nodes) > 0 && !uu.mutation.DavAccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.DavAccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedSharesIDs(); len(nodes) > 0 && !uu.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + 
edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.PasskeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedPasskeyIDs(); len(nodes) > 0 && !uu.mutation.PasskeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.PasskeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.TasksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedTasksIDs(); len(nodes) > 0 && !uu.mutation.TasksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.TasksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedEntitiesIDs(); len(nodes) > 0 
&& !uu.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{user.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + uu.mutation.done = true + return n, nil +} + +// UserUpdateOne is the builder for updating a single User entity. +type UserUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UserMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (uuo *UserUpdateOne) SetUpdatedAt(t time.Time) *UserUpdateOne { + uuo.mutation.SetUpdatedAt(t) + return uuo +} + +// SetDeletedAt sets the "deleted_at" field. +func (uuo *UserUpdateOne) SetDeletedAt(t time.Time) *UserUpdateOne { + uuo.mutation.SetDeletedAt(t) + return uuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDeletedAt(t *time.Time) *UserUpdateOne { + if t != nil { + uuo.SetDeletedAt(*t) + } + return uuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (uuo *UserUpdateOne) ClearDeletedAt() *UserUpdateOne { + uuo.mutation.ClearDeletedAt() + return uuo +} + +// SetEmail sets the "email" field. +func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne { + uuo.mutation.SetEmail(s) + return uuo +} + +// SetNillableEmail sets the "email" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableEmail(s *string) *UserUpdateOne { + if s != nil { + uuo.SetEmail(*s) + } + return uuo +} + +// SetNick sets the "nick" field. +func (uuo *UserUpdateOne) SetNick(s string) *UserUpdateOne { + uuo.mutation.SetNick(s) + return uuo +} + +// SetNillableNick sets the "nick" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableNick(s *string) *UserUpdateOne { + if s != nil { + uuo.SetNick(*s) + } + return uuo +} + +// SetPassword sets the "password" field. +func (uuo *UserUpdateOne) SetPassword(s string) *UserUpdateOne { + uuo.mutation.SetPassword(s) + return uuo +} + +// SetNillablePassword sets the "password" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillablePassword(s *string) *UserUpdateOne { + if s != nil { + uuo.SetPassword(*s) + } + return uuo +} + +// ClearPassword clears the value of the "password" field. +func (uuo *UserUpdateOne) ClearPassword() *UserUpdateOne { + uuo.mutation.ClearPassword() + return uuo +} + +// SetStatus sets the "status" field. 
+func (uuo *UserUpdateOne) SetStatus(u user.Status) *UserUpdateOne { + uuo.mutation.SetStatus(u) + return uuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableStatus(u *user.Status) *UserUpdateOne { + if u != nil { + uuo.SetStatus(*u) + } + return uuo +} + +// SetStorage sets the "storage" field. +func (uuo *UserUpdateOne) SetStorage(i int64) *UserUpdateOne { + uuo.mutation.ResetStorage() + uuo.mutation.SetStorage(i) + return uuo +} + +// SetNillableStorage sets the "storage" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableStorage(i *int64) *UserUpdateOne { + if i != nil { + uuo.SetStorage(*i) + } + return uuo +} + +// AddStorage adds i to the "storage" field. +func (uuo *UserUpdateOne) AddStorage(i int64) *UserUpdateOne { + uuo.mutation.AddStorage(i) + return uuo +} + +// SetTwoFactorSecret sets the "two_factor_secret" field. +func (uuo *UserUpdateOne) SetTwoFactorSecret(s string) *UserUpdateOne { + uuo.mutation.SetTwoFactorSecret(s) + return uuo +} + +// SetNillableTwoFactorSecret sets the "two_factor_secret" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableTwoFactorSecret(s *string) *UserUpdateOne { + if s != nil { + uuo.SetTwoFactorSecret(*s) + } + return uuo +} + +// ClearTwoFactorSecret clears the value of the "two_factor_secret" field. +func (uuo *UserUpdateOne) ClearTwoFactorSecret() *UserUpdateOne { + uuo.mutation.ClearTwoFactorSecret() + return uuo +} + +// SetAvatar sets the "avatar" field. +func (uuo *UserUpdateOne) SetAvatar(s string) *UserUpdateOne { + uuo.mutation.SetAvatar(s) + return uuo +} + +// SetNillableAvatar sets the "avatar" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableAvatar(s *string) *UserUpdateOne { + if s != nil { + uuo.SetAvatar(*s) + } + return uuo +} + +// ClearAvatar clears the value of the "avatar" field. +func (uuo *UserUpdateOne) ClearAvatar() *UserUpdateOne { + uuo.mutation.ClearAvatar() + return uuo +} + +// SetSettings sets the "settings" field. +func (uuo *UserUpdateOne) SetSettings(ts *types.UserSetting) *UserUpdateOne { + uuo.mutation.SetSettings(ts) + return uuo +} + +// ClearSettings clears the value of the "settings" field. +func (uuo *UserUpdateOne) ClearSettings() *UserUpdateOne { + uuo.mutation.ClearSettings() + return uuo +} + +// SetGroupUsers sets the "group_users" field. +func (uuo *UserUpdateOne) SetGroupUsers(i int) *UserUpdateOne { + uuo.mutation.SetGroupUsers(i) + return uuo +} + +// SetNillableGroupUsers sets the "group_users" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableGroupUsers(i *int) *UserUpdateOne { + if i != nil { + uuo.SetGroupUsers(*i) + } + return uuo +} + +// SetGroupID sets the "group" edge to the Group entity by ID. +func (uuo *UserUpdateOne) SetGroupID(id int) *UserUpdateOne { + uuo.mutation.SetGroupID(id) + return uuo +} + +// SetGroup sets the "group" edge to the Group entity. +func (uuo *UserUpdateOne) SetGroup(g *Group) *UserUpdateOne { + return uuo.SetGroupID(g.ID) +} + +// AddFileIDs adds the "files" edge to the File entity by IDs. +func (uuo *UserUpdateOne) AddFileIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddFileIDs(ids...) + return uuo +} + +// AddFiles adds the "files" edges to the File entity. +func (uuo *UserUpdateOne) AddFiles(f ...*File) *UserUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return uuo.AddFileIDs(ids...) 
+} + +// AddDavAccountIDs adds the "dav_accounts" edge to the DavAccount entity by IDs. +func (uuo *UserUpdateOne) AddDavAccountIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddDavAccountIDs(ids...) + return uuo +} + +// AddDavAccounts adds the "dav_accounts" edges to the DavAccount entity. +func (uuo *UserUpdateOne) AddDavAccounts(d ...*DavAccount) *UserUpdateOne { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return uuo.AddDavAccountIDs(ids...) +} + +// AddShareIDs adds the "shares" edge to the Share entity by IDs. +func (uuo *UserUpdateOne) AddShareIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddShareIDs(ids...) + return uuo +} + +// AddShares adds the "shares" edges to the Share entity. +func (uuo *UserUpdateOne) AddShares(s ...*Share) *UserUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return uuo.AddShareIDs(ids...) +} + +// AddPasskeyIDs adds the "passkey" edge to the Passkey entity by IDs. +func (uuo *UserUpdateOne) AddPasskeyIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddPasskeyIDs(ids...) + return uuo +} + +// AddPasskey adds the "passkey" edges to the Passkey entity. +func (uuo *UserUpdateOne) AddPasskey(p ...*Passkey) *UserUpdateOne { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uuo.AddPasskeyIDs(ids...) +} + +// AddTaskIDs adds the "tasks" edge to the Task entity by IDs. +func (uuo *UserUpdateOne) AddTaskIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddTaskIDs(ids...) + return uuo +} + +// AddTasks adds the "tasks" edges to the Task entity. +func (uuo *UserUpdateOne) AddTasks(t ...*Task) *UserUpdateOne { + ids := make([]int, len(t)) + for i := range t { + ids[i] = t[i].ID + } + return uuo.AddTaskIDs(ids...) +} + +// AddEntityIDs adds the "entities" edge to the Entity entity by IDs. +func (uuo *UserUpdateOne) AddEntityIDs(ids ...int) *UserUpdateOne { + uuo.mutation.AddEntityIDs(ids...) + return uuo +} + +// AddEntities adds the "entities" edges to the Entity entity. +func (uuo *UserUpdateOne) AddEntities(e ...*Entity) *UserUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return uuo.AddEntityIDs(ids...) +} + +// Mutation returns the UserMutation object of the builder. +func (uuo *UserUpdateOne) Mutation() *UserMutation { + return uuo.mutation +} + +// ClearGroup clears the "group" edge to the Group entity. +func (uuo *UserUpdateOne) ClearGroup() *UserUpdateOne { + uuo.mutation.ClearGroup() + return uuo +} + +// ClearFiles clears all "files" edges to the File entity. +func (uuo *UserUpdateOne) ClearFiles() *UserUpdateOne { + uuo.mutation.ClearFiles() + return uuo +} + +// RemoveFileIDs removes the "files" edge to File entities by IDs. +func (uuo *UserUpdateOne) RemoveFileIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemoveFileIDs(ids...) + return uuo +} + +// RemoveFiles removes "files" edges to File entities. +func (uuo *UserUpdateOne) RemoveFiles(f ...*File) *UserUpdateOne { + ids := make([]int, len(f)) + for i := range f { + ids[i] = f[i].ID + } + return uuo.RemoveFileIDs(ids...) +} + +// ClearDavAccounts clears all "dav_accounts" edges to the DavAccount entity. +func (uuo *UserUpdateOne) ClearDavAccounts() *UserUpdateOne { + uuo.mutation.ClearDavAccounts() + return uuo +} + +// RemoveDavAccountIDs removes the "dav_accounts" edge to DavAccount entities by IDs. +func (uuo *UserUpdateOne) RemoveDavAccountIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemoveDavAccountIDs(ids...) 
+ return uuo +} + +// RemoveDavAccounts removes "dav_accounts" edges to DavAccount entities. +func (uuo *UserUpdateOne) RemoveDavAccounts(d ...*DavAccount) *UserUpdateOne { + ids := make([]int, len(d)) + for i := range d { + ids[i] = d[i].ID + } + return uuo.RemoveDavAccountIDs(ids...) +} + +// ClearShares clears all "shares" edges to the Share entity. +func (uuo *UserUpdateOne) ClearShares() *UserUpdateOne { + uuo.mutation.ClearShares() + return uuo +} + +// RemoveShareIDs removes the "shares" edge to Share entities by IDs. +func (uuo *UserUpdateOne) RemoveShareIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemoveShareIDs(ids...) + return uuo +} + +// RemoveShares removes "shares" edges to Share entities. +func (uuo *UserUpdateOne) RemoveShares(s ...*Share) *UserUpdateOne { + ids := make([]int, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return uuo.RemoveShareIDs(ids...) +} + +// ClearPasskey clears all "passkey" edges to the Passkey entity. +func (uuo *UserUpdateOne) ClearPasskey() *UserUpdateOne { + uuo.mutation.ClearPasskey() + return uuo +} + +// RemovePasskeyIDs removes the "passkey" edge to Passkey entities by IDs. +func (uuo *UserUpdateOne) RemovePasskeyIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemovePasskeyIDs(ids...) + return uuo +} + +// RemovePasskey removes "passkey" edges to Passkey entities. +func (uuo *UserUpdateOne) RemovePasskey(p ...*Passkey) *UserUpdateOne { + ids := make([]int, len(p)) + for i := range p { + ids[i] = p[i].ID + } + return uuo.RemovePasskeyIDs(ids...) +} + +// ClearTasks clears all "tasks" edges to the Task entity. +func (uuo *UserUpdateOne) ClearTasks() *UserUpdateOne { + uuo.mutation.ClearTasks() + return uuo +} + +// RemoveTaskIDs removes the "tasks" edge to Task entities by IDs. +func (uuo *UserUpdateOne) RemoveTaskIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemoveTaskIDs(ids...) + return uuo +} + +// RemoveTasks removes "tasks" edges to Task entities. +func (uuo *UserUpdateOne) RemoveTasks(t ...*Task) *UserUpdateOne { + ids := make([]int, len(t)) + for i := range t { + ids[i] = t[i].ID + } + return uuo.RemoveTaskIDs(ids...) +} + +// ClearEntities clears all "entities" edges to the Entity entity. +func (uuo *UserUpdateOne) ClearEntities() *UserUpdateOne { + uuo.mutation.ClearEntities() + return uuo +} + +// RemoveEntityIDs removes the "entities" edge to Entity entities by IDs. +func (uuo *UserUpdateOne) RemoveEntityIDs(ids ...int) *UserUpdateOne { + uuo.mutation.RemoveEntityIDs(ids...) + return uuo +} + +// RemoveEntities removes "entities" edges to Entity entities. +func (uuo *UserUpdateOne) RemoveEntities(e ...*Entity) *UserUpdateOne { + ids := make([]int, len(e)) + for i := range e { + ids[i] = e[i].ID + } + return uuo.RemoveEntityIDs(ids...) +} + +// Where appends a list predicates to the UserUpdate builder. +func (uuo *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { + uuo.mutation.Where(ps...) + return uuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (uuo *UserUpdateOne) Select(field string, fields ...string) *UserUpdateOne { + uuo.fields = append([]string{field}, fields...) + return uuo +} + +// Save executes the query and returns the updated User entity. 
+func (uuo *UserUpdateOne) Save(ctx context.Context) (*User, error) { + if err := uuo.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, uuo.sqlSave, uuo.mutation, uuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (uuo *UserUpdateOne) SaveX(ctx context.Context) *User { + node, err := uuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (uuo *UserUpdateOne) Exec(ctx context.Context) error { + _, err := uuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (uuo *UserUpdateOne) ExecX(ctx context.Context) { + if err := uuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (uuo *UserUpdateOne) defaults() error { + if _, ok := uuo.mutation.UpdatedAt(); !ok { + if user.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized user.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := user.UpdateDefaultUpdatedAt() + uuo.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (uuo *UserUpdateOne) check() error { + if v, ok := uuo.mutation.Email(); ok { + if err := user.EmailValidator(v); err != nil { + return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} + } + } + if v, ok := uuo.mutation.Nick(); ok { + if err := user.NickValidator(v); err != nil { + return &ValidationError{Name: "nick", err: fmt.Errorf(`ent: validator failed for field "User.nick": %w`, err)} + } + } + if v, ok := uuo.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } + if _, ok := uuo.mutation.GroupID(); uuo.mutation.GroupCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "User.group"`) + } + return nil +} + +func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { + if err := uuo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) + id, ok := uuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "User.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := uuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) + for _, f := range fields { + if !user.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != user.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := uuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := uuo.mutation.UpdatedAt(); ok { + _spec.SetField(user.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := uuo.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + } + if uuo.mutation.DeletedAtCleared() { + _spec.ClearField(user.FieldDeletedAt, field.TypeTime) + } + if value, ok := uuo.mutation.Email(); ok { + _spec.SetField(user.FieldEmail, field.TypeString, value) + } + if value, ok := uuo.mutation.Nick(); ok { + _spec.SetField(user.FieldNick, 
field.TypeString, value) + } + if value, ok := uuo.mutation.Password(); ok { + _spec.SetField(user.FieldPassword, field.TypeString, value) + } + if uuo.mutation.PasswordCleared() { + _spec.ClearField(user.FieldPassword, field.TypeString) + } + if value, ok := uuo.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + } + if value, ok := uuo.mutation.Storage(); ok { + _spec.SetField(user.FieldStorage, field.TypeInt64, value) + } + if value, ok := uuo.mutation.AddedStorage(); ok { + _spec.AddField(user.FieldStorage, field.TypeInt64, value) + } + if value, ok := uuo.mutation.TwoFactorSecret(); ok { + _spec.SetField(user.FieldTwoFactorSecret, field.TypeString, value) + } + if uuo.mutation.TwoFactorSecretCleared() { + _spec.ClearField(user.FieldTwoFactorSecret, field.TypeString) + } + if value, ok := uuo.mutation.Avatar(); ok { + _spec.SetField(user.FieldAvatar, field.TypeString, value) + } + if uuo.mutation.AvatarCleared() { + _spec.ClearField(user.FieldAvatar, field.TypeString) + } + if value, ok := uuo.mutation.Settings(); ok { + _spec.SetField(user.FieldSettings, field.TypeJSON, value) + } + if uuo.mutation.SettingsCleared() { + _spec.ClearField(user.FieldSettings, field.TypeJSON) + } + if uuo.mutation.GroupCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: user.GroupTable, + Columns: []string{user.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: user.GroupTable, + Columns: []string{user.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedFilesIDs(); len(nodes) > 0 && !uuo.mutation.FilesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.FilesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.FilesTable, + Columns: []string{user.FilesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(file.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.DavAccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: 
sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedDavAccountsIDs(); len(nodes) > 0 && !uuo.mutation.DavAccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.DavAccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.DavAccountsTable, + Columns: []string{user.DavAccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(davaccount.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedSharesIDs(); len(nodes) > 0 && !uuo.mutation.SharesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.SharesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.SharesTable, + Columns: []string{user.SharesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(share.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.PasskeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedPasskeyIDs(); len(nodes) > 0 && !uuo.mutation.PasskeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.PasskeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PasskeyTable, + Columns: []string{user.PasskeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(passkey.FieldID, 
field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.TasksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedTasksIDs(); len(nodes) > 0 && !uuo.mutation.TasksCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.TasksIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.TasksTable, + Columns: []string{user.TasksColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(task.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if uuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedEntitiesIDs(); len(nodes) > 0 && !uuo.mutation.EntitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.EntitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.EntitiesTable, + Columns: []string{user.EntitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(entity.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &User{config: uuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, uuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{user.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + uuo.mutation.done = true + return _node, nil +} diff --git a/go.mod b/go.mod index d32a7c7e..0d57b859 100644 --- a/go.mod +++ b/go.mod @@ -1,177 +1,167 @@ -module github.com/cloudreve/Cloudreve/v3 +module github.com/cloudreve/Cloudreve/v4 -go 1.18 +go 1.23 require ( - github.com/DATA-DOG/go-sqlmock v1.3.3 - github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible + entgo.io/ent v0.13.0 + 
github.com/abslant/gzip v0.0.9 + github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible github.com/aws/aws-sdk-go v1.31.5 - github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499 + github.com/cloudflare/cfssl v1.6.1 + github.com/dhowden/tag v0.0.0-20230630033851-978a0926ee25 + github.com/dsoprea/go-exif/v3 v3.0.1 + github.com/dsoprea/go-heic-exif-extractor v0.0.0-20210512044107-62067e44c235 + github.com/dsoprea/go-jpeg-image-structure v0.0.0-20221012074422-4f3f7e934102 + github.com/dsoprea/go-png-image-structure v0.0.0-20210512210324-29b889a6093d + github.com/dsoprea/go-tiff-image-structure v0.0.0-20221003165014-8ecc4f52edca + github.com/dsoprea/go-utility v0.0.0-20200711062821-fab8125e9bdf github.com/fatih/color v1.9.0 github.com/gin-contrib/cors v1.3.0 - github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8 - github.com/gin-contrib/sessions v0.0.5 + github.com/gin-contrib/sessions v1.0.2 github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2 - github.com/gin-gonic/gin v1.8.1 - github.com/glebarez/go-sqlite v1.20.3 + github.com/gin-gonic/gin v1.10.0 + github.com/glebarez/go-sqlite v1.22.0 github.com/go-ini/ini v1.50.0 github.com/go-mail/mail v2.3.1+incompatible - github.com/go-playground/validator/v10 v10.11.0 + github.com/go-pay/gopay v1.5.109 + github.com/go-playground/validator/v10 v10.20.0 + github.com/go-sql-driver/mysql v1.6.0 + github.com/go-webauthn/webauthn v0.11.2 github.com/gofrs/uuid v4.0.0+incompatible - github.com/gomodule/redigo v2.0.0+incompatible - github.com/google/go-querystring v1.0.0 - github.com/gorilla/securecookie v1.1.1 - github.com/gorilla/sessions v1.2.1 - github.com/gorilla/websocket v1.4.2 - github.com/hashicorp/go-version v1.3.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/gomodule/redigo v1.9.2 + github.com/google/go-querystring v1.1.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/securecookie v1.1.2 + github.com/gorilla/sessions v1.2.2 + github.com/gorilla/websocket v1.5.0 + github.com/huaweicloud/huaweicloud-sdk-go-obs v3.24.6+incompatible github.com/jinzhu/gorm v1.9.11 + github.com/jpillora/backoff v1.0.0 github.com/juju/ratelimit v1.0.1 + github.com/lib/pq v1.10.9 + github.com/logto-io/go/v2 v2.0.0 github.com/mholt/archiver/v4 v4.0.0-alpha.6 github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2 - github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.2.0 - github.com/qiniu/go-sdk/v7 v7.11.1 + github.com/qiniu/go-sdk/v7 v7.19.0 github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 github.com/robfig/cron/v3 v3.0.1 github.com/samber/lo v1.38.1 + github.com/smartwalle/alipay/v3 v3.2.25 github.com/speps/go-hashids v2.0.0+incompatible - github.com/stretchr/testify v1.7.2 - github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 - github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393 - github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393 - github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac + github.com/spf13/cobra v1.7.0 + github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 + github.com/stripe/stripe-go/v81 v81.0.0 + github.com/tencentyun/cos-go-sdk-v5 v0.7.54 + github.com/ua-parser/uap-go v0.0.0-20250213224047-9c035f085b90 github.com/upyun/go-sdk v2.1.0+incompatible + golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e golang.org/x/image v0.0.0-20211028202545-6944b10bf410 - golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba - google.golang.org/api v0.45.0 + golang.org/x/time v0.5.0 + 
golang.org/x/tools v0.24.0 + modernc.org/sqlite v1.28.0 ) require ( + ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 // indirect cloud.google.com/go v0.81.0 // indirect + github.com/agext/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.0.4 // indirect - github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect - github.com/beorn7/perks v1.0.1 // indirect - github.com/bgentry/speakeasy v0.1.0 // indirect + github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect - github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect - github.com/cespare/xxhash/v2 v2.1.1 // indirect - github.com/cloudflare/cfssl v1.6.1 // indirect - github.com/cncf/udpa/go v0.0.0-20210322005330-6414d713912e // indirect - github.com/coreos/go-semver v0.3.0 // indirect - github.com/coreos/go-systemd/v22 v22.3.2 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/clbanning/mxj v1.8.4 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 // indirect github.com/dsnet/compress v0.0.1 // indirect + github.com/dsoprea/go-exif/v2 v2.0.0-20200604193436-ca8584a0e1c4 // indirect + github.com/dsoprea/go-iptc v0.0.0-20200609062250-162ae6b44feb // indirect + github.com/dsoprea/go-logging v0.0.0-20200710184922-b02d349568dd // indirect + github.com/dsoprea/go-photoshop-info-format v0.0.0-20200609050348-3db9b63b202c // indirect + github.com/dsoprea/go-utility/v2 v2.0.0-20221003172846-a3e1774ef349 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d // indirect - github.com/envoyproxy/protoc-gen-validate v0.6.1 // indirect - github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect - github.com/fullstorydev/grpcurl v1.8.1 // indirect - github.com/fxamacker/cbor/v2 v2.4.0 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-playground/locales v0.14.0 // indirect - github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/go-sql-driver/mysql v1.6.0 // indirect - github.com/goccy/go-json v0.9.8 // indirect - github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v4 v4.1.0 // indirect + github.com/go-errors/errors v1.4.2 // indirect + github.com/go-jose/go-jose/v4 v4.0.4 // indirect + github.com/go-openapi/inflect v0.19.0 // indirect + github.com/go-pay/crypto v0.0.1 // indirect + github.com/go-pay/errgroup v0.0.3 // indirect + github.com/go-pay/smap v0.0.2 // indirect + github.com/go-pay/util v0.0.4 // indirect + github.com/go-pay/xlog v0.0.3 // indirect + github.com/go-pay/xtime v0.0.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-webauthn/x v0.1.14 // indirect + github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - 
github.com/golang/mock v1.5.0 // indirect - github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/btree v1.0.1 // indirect - github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 // indirect - github.com/google/go-cmp v0.5.9 // indirect - github.com/google/uuid v1.3.0 // indirect - github.com/googleapis/gax-go/v2 v2.0.5 // indirect - github.com/gorilla/context v1.1.1 // indirect - github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect - github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect - github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect - github.com/inconshreveable/mousetrap v1.0.0 // indirect - github.com/jhump/protoreflect v1.8.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-tpm v0.9.1 // indirect + github.com/gorilla/context v1.1.2 // indirect + github.com/hashicorp/golang-lru v0.5.4 // indirect + github.com/hashicorp/hcl/v2 v2.13.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jmespath/go-jmespath v0.3.0 // indirect - github.com/jonboulle/clockwork v0.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.15.1 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/pgzip v1.2.5 // indirect - github.com/leodido/go-urn v1.2.1 // indirect - github.com/lib/pq v1.10.3 // indirect - github.com/mattn/go-colorable v0.1.4 // indirect - github.com/mattn/go-isatty v0.0.17 // indirect - github.com/mattn/go-runewidth v0.0.12 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect - github.com/mitchellh/mapstructure v1.1.2 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.6 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/mozillazg/go-httpheader v0.2.1 // indirect + github.com/mozillazg/go-httpheader v0.4.0 // indirect github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/pelletier/go-toml/v2 v2.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pierrec/lz4/v4 v4.1.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_golang v1.10.0 // indirect - github.com/prometheus/client_model v0.2.0 // indirect - github.com/prometheus/common v0.24.0 // indirect - github.com/prometheus/procfs v0.6.0 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect - github.com/rivo/uniseg v0.2.0 // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/satori/go.uuid v1.2.0 // indirect - github.com/sirupsen/logrus v1.8.1 // indirect - github.com/soheilhy/cmux v0.1.5 // indirect - github.com/spf13/cobra v1.1.3 // indirect - github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.2.0 // indirect + 
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/smartwalle/ncrypto v1.0.4 // indirect + github.com/smartwalle/ngx v1.0.9 // indirect + github.com/smartwalle/nsign v1.0.9 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/therootcompany/xz v1.0.1 // indirect - github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect - github.com/ugorji/go/codec v1.2.7 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect github.com/ulikunitz/xz v0.5.10 // indirect - github.com/urfave/cli v1.22.5 // indirect github.com/x448/float16 v0.8.4 // indirect - github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect - go.etcd.io/bbolt v1.3.5 // indirect - go.etcd.io/etcd/api/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/client/v2 v2.305.0-alpha.0 // indirect - go.etcd.io/etcd/client/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/etcdctl/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/server/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/tests/v3 v3.5.0-alpha.0 // indirect - go.etcd.io/etcd/v3 v3.5.0-alpha.0 // indirect - go.opencensus.io v0.23.0 // indirect - go.uber.org/atomic v1.7.0 // indirect - go.uber.org/multierr v1.7.0 // indirect - go.uber.org/zap v1.16.0 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect - golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect - golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect - golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.4.0 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a // indirect - google.golang.org/grpc v1.37.0 // indirect - google.golang.org/protobuf v1.28.0 // indirect + github.com/zclconf/go-cty v1.8.0 // indirect + go4.org v0.0.0-20200411211856-f5505b9728dd // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/mod v0.20.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect - gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect gopkg.in/mail.v2 v2.3.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - modernc.org/libc v1.22.2 // indirect - modernc.org/mathutil v1.5.0 // indirect - modernc.org/memory v1.5.0 // indirect - modernc.org/sqlite v1.20.3 // indirect - sigs.k8s.io/yaml v1.2.0 // indirect - + lukechampine.com/uint128 v1.3.0 // indirect + modernc.org/cc/v3 v3.41.0 // indirect + modernc.org/ccgo/v3 v3.16.15 // indirect + modernc.org/libc v1.37.6 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect ) - -replace 
github.com/gomodule/redigo v2.0.0+incompatible => github.com/gomodule/redigo v1.8.9 diff --git a/go.sum b/go.sum deleted file mode 100644 index 64a345d6..00000000 --- a/go.sum +++ /dev/null @@ -1,1464 +0,0 @@ -bazil.org/fuse v0.0.0-20180421153158-65cc252bf669/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8= -bitbucket.org/creachadair/shell v0.0.6/go.mod h1:8Qqi/cYk7vPnsOePHroKXDJYmb5x7ENhtiFtfZq8K+M= -bitbucket.org/liamstask/goose v0.0.0-20150115234039-8488cc47d90c/go.mod h1:hSVuE3qU7grINVSwrmzHfpg9k87ALBk+XaualNyUzI4= -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.39.0/go.mod h1:rVLT6fkc8chs9sfPtFc1SBH6em7n+ZoXaG+87tDISts= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.44.3/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= -cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk= -cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= -cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= -cloud.google.com/go v0.81.0 h1:at8Tk2zUz63cLPR0JPWm5vp77pEZmzxEQBEfRKn1VV8= -cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= 
-cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/spanner v1.17.0/go.mod h1:+17t2ixFwRG4lWRwE+5kipDR9Ef07Jkmc8z0IbMDKUs= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= -cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= -cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= -cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -code.gitea.io/sdk/gitea v0.11.3/go.mod h1:z3uwDV/b9Ls47NGukYM9XhnHtqPh/J+t40lsUrR6JDY= -contrib.go.opencensus.io/exporter/aws v0.0.0-20181029163544-2befc13012d0/go.mod h1:uu1P0UCM/6RbsMrgPa98ll8ZcHM858i/AD06a9aLRCA= -contrib.go.opencensus.io/exporter/ocagent v0.5.0/go.mod h1:ImxhfLRpxoYiSq891pBrLVhN+qmP8BTVvdH2YLs7Gl0= -contrib.go.opencensus.io/exporter/stackdriver v0.12.1/go.mod h1:iwB6wGarfphGGe/e5CWqyUk/cLzKnWsOKPVW3no6OTw= -contrib.go.opencensus.io/exporter/stackdriver v0.13.5/go.mod h1:aXENhDJ1Y4lIg4EUaVTwzvYETVNZk10Pu26tevFKLUc= -contrib.go.opencensus.io/integrations/ocsql v0.1.4/go.mod h1:8DsSdjz3F+APR+0z0WkU1aRorQCFfRxvqjUUPMbF3fE= -contrib.go.opencensus.io/resource v0.1.1/go.mod h1:F361eGI91LCmW1I/Saf+rX0+OFcigGlFvXwEGEnkRLA= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/Azure/azure-amqp-common-go/v2 v2.1.0/go.mod h1:R8rea+gJRuJR6QxTir/XuEd+YuKoUiazDC/N96FiDEU= -github.com/Azure/azure-pipeline-go v0.2.1/go.mod h1:UGSo8XybXnIGZ3epmeBw7Jdz+HiUVpqIlpz/HKHylF4= -github.com/Azure/azure-sdk-for-go v29.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go v30.1.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-service-bus-go v0.9.1/go.mod h1:yzBx6/BUGfjfeqbRZny9AQIbIe3AcV9WZbAdpkoXOa0= -github.com/Azure/azure-storage-blob-go v0.8.0/go.mod h1:lPI3aLPpuLTeUwh1sViKXFxwl2B6teiRqI0deQUvsw0= -github.com/Azure/go-autorest v12.0.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08= -github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0= -github.com/GeertJohan/go.rice v1.0.2/go.mod h1:af5vUNlDNkCjOZeSGFgIJxDje9qdjsO6hshx0gTmZt4= -github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20191009163259-e802c2cb94ae/go.mod h1:mjwGPas4yKduTyubHvD1Atl9r1rUq8DfVy+gkVvZ+oo= -github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible h1:aX/+gJM2dAMDDy3JqWS0DJn3JfOUchf4k37P5TbBKU8= -github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible/go.mod h1:8KDiKVrHK/UbXAhj+iQGp1m40rQa+UAvzBi7m22KywI= -github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/Masterminds/goutils v1.1.0/go.mod 
h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= -github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/Masterminds/semver/v3 v3.1.0/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/Masterminds/sprig v2.15.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o= -github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/QcloudApi/qcloud_sign_golang v0.0.0-20141224014652-e4130a326409/go.mod h1:1pk82RBxDY/JZnPQrtqHlUFfCctgdorsd9M06fMynOM= -github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= -github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= -github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= -github.com/alcortesm/tgz v0.0.0-20161220082320-9c5fe88206d7/go.mod h1:6zEj6s6u/ghQa61ZWa/C2Aw3RkjiTBOix7dkqa1VLIs= -github.com/alecthomas/kingpin v2.2.6+incompatible/go.mod h1:59OFYbFVLKQKq+mqrL6Rw5bR0c3ACQaawgXx0QYndlE= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= -github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= -github.com/aokoli/goutils v1.0.1/go.mod h1:SijmP0QR8LtwsmDs8Yii5Z/S4trXFGFC2oO5g9DP+DQ= -github.com/apache/beam v2.28.0+incompatible/go.mod h1:/8NX3Qi8vGstDLLaeaU7+lzVEu/ACaQhYjeefzQ0y1o= -github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apex/log v1.1.4/go.mod h1:AlpoD9aScyQfJDVHmLMEcx4oU6LqzkWp4Mg9GdAcEvQ= -github.com/apex/logs v0.0.4/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDwo= -github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy8kCu4PNA+aP7WUV72eXWJeP9/r3/K9aLE= -github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= -github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= 
-github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= -github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= -github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= -github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= -github.com/aws/aws-sdk-go v1.19.18/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.19.45/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.25.11/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go v1.31.5 h1:DFA7BzTydO4etqsTja+x7UfkOKQUv1xzEluLvNk81L0= -github.com/aws/aws-sdk-go v1.31.5/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= -github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= -github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= -github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f h1:ZNv7On9kyUzm7fvRZumSyy/IUiSC7AzL0I1jKKtwooA= -github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= -github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb/go.mod h1:PkYb9DJNAwrSvRx5DYA+gUcOIgTGVMNkfSCbZM8cWpI= -github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= -github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/caarlos0/ctrlc v1.0.0/go.mod h1:CdXpj4rmq0q/1Eb44M9zi2nKB0QraNKuRGYGrrHhcQw= -github.com/campoy/unique v0.0.0-20180121183637-88950e537e7e/go.mod h1:9IOqJGCPMSc6E5ydlp5NIonxObaeu/Iub/X03EKPVYo= -github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= -github.com/cavaliercoder/go-cpio v0.0.0-20180626203310-925f9528c45e/go.mod h1:oDpT4efm8tSYHXV5tHSdRvBet/b/QzxZ+XyyPehvm3A= -github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= -github.com/census-instrumentation/opencensus-proto v0.2.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/census-instrumentation/opencensus-proto 
v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/census-instrumentation/opencensus-proto v0.3.0 h1:t/LhUZLVitR1Ow2YOnduCsavhwFUklBMoGVYUCqmCqk= -github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= -github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d h1:S2NE3iHSwP0XV47EEXL8mWmRdEfGscSJ+7EgePNgt0s= -github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= -github.com/cloudflare/cfssl v1.6.1 h1:aIOUjpeuDJOpWjVJFP2ByplF53OgqG8I1S40Ggdlk3g= -github.com/cloudflare/cfssl v1.6.1/go.mod h1:ENhCj4Z17+bY2XikpxVmTHDg/C2IsG2Q0ZBeXpAqhCk= -github.com/cloudflare/redoctober v0.0.0-20201013214028-99c99a8e7544/go.mod h1:6Se34jNoqrd8bTxrmJB2Bg2aoZ2CdSXonils9NsiNgo= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20210322005330-6414d713912e h1:xjKi0OrdbKVCLWRoF2SGNnv9todhp+zQlvRHhsb14R4= -github.com/cncf/udpa/go v0.0.0-20210322005330-6414d713912e/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= -github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5 h1:xD/lrqdvwsc+O2bjSSi3YqY73Ke3LAiSCx49aCesA0E= -github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo= -github.com/cockroachdb/errors v1.2.4 h1:Lap807SXTH5tri2TivECb/4abUkMZC9zRoLarvcKDqs= -github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoCr5oaCLELYA= -github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f h1:o/kfcElHqOiXqcou5a3rIlMc7oJbMQkeLk0VQJ7zgqY= -github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI= -github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= -github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/etcd 
v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= -github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= -github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= -github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CLnGVd57E= -github.com/daaku/go.zipexe v1.0.1/go.mod h1:5xWogtqlYnfBXkSB1o9xysukNP9GTvaNkqzUZbt3Bw8= -github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 h1:tkum0XDgfR0jcVVXuTsYv/erY2NnEDqwRojbxR1rBYA= -github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= -github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= -github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8= -github.com/dsnet/compress v0.0.1 h1:PlZu0n3Tuv04TzpfPbrnI0HW/YwodEXDS+oPKahKF0Q= -github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= -github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= -github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499 h1:jaQHuGKk9NVcfu9VbA7ygslr/7utxdYs47i4osBhZP8= -github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499/go.mod 
h1:UMk1JMDgQDcdI2vQz+WJOIUTSjIq07qSepAVgc93rUc= -github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= -github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= -github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o= -github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= -github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d h1:QyzYnTnPE15SQyUeqU6qLbWxMkwyAyu+vGksa0b7j00= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v0.3.0-java/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v0.6.1 h1:4CF52PCseTFt4bE+Yk3dIpdVi7XWuPVMhPtm4FaIJPM= -github.com/envoyproxy/protoc-gen-validate v0.6.1/go.mod h1:txg5va2Qkip90uYoSKH+nkAAmXrb2j3iq4FLwdrCbXQ= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/etcd-io/gofail v0.0.0-20190801230047-ad7f989257ca/go.mod h1:49H/RkXP8pKaZy4h0d+NW16rSLhyVBt4o6VLJbmOqDE= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= -github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= -github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/fortytw2/leaktest v1.2.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= -github.com/fortytw2/leaktest 
v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= -github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= -github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fullstorydev/grpcurl v1.8.0/go.mod h1:Mn2jWbdMrQGJQ8UD62uNyMumT2acsZUCkZIqFxsQf1o= -github.com/fullstorydev/grpcurl v1.8.1 h1:Pp648wlTTg3OKySeqxM5pzh8XF6vLqrm8wRq66+5Xo0= -github.com/fullstorydev/grpcurl v1.8.1/go.mod h1:3BWhvHZwNO7iLXaQlojdg5NA6SxUDePli4ecpK1N7gw= -github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= -github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= -github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= -github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/gin-contrib/cors v1.3.0 h1:PolezCc89peu+NgkIWt9OB01Kbzt6IP0J/JvkG6xxlg= -github.com/gin-contrib/cors v1.3.0/go.mod h1:artPvLlhkF7oG06nK8v3U8TNz6IeX+w1uzCSEId5/Vc= -github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8 h1:/DnKeA2+K83hkii3nqMJ5koknI+/qlojjxgcSyiAyJw= -github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8/go.mod h1:M+xPw/lXk+uAU4iYVnwPZs0iIpR/KwSQSXcJabN+gPs= -github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= -github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= -github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2 h1:xLG16iua01X7Gzms9045s2Y2niNpvSY/Zb1oBwgNYZY= -github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2/go.mod h1:VhW/Ch/3FhimwZb8Oj+qJmdMmoB8r7lmJ5auRjm50oQ= -github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= -github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= -github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= -github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= -github.com/glebarez/go-sqlite v1.20.3 h1:89BkqGOXR9oRmG58ZrzgoY/Fhy5x0M+/WV48U5zVrZ4= -github.com/glebarez/go-sqlite v1.20.3/go.mod h1:u3N6D/wftiAzIOJtZl6BmedqxmmkDfH3q+ihjqxC9u0= -github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-ini/ini v1.25.4/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= -github.com/go-ini/ini v1.50.0 h1:ogX6RS8VstVN8MJcwhEP78hHhWaI3klN02+97bByabY= -github.com/go-ini/ini v1.50.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= 
-github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-mail/mail v2.3.1+incompatible h1:UzNOn0k5lpfVtO31cK3hn6I4VEVGhe3lX8AJBAxXExM= -github.com/go-mail/mail v2.3.1+incompatible/go.mod h1:VPWjmmNyRsWXQZHVHT3g0YbIINUkSmuKOiLIDkWbL6M= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= -github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= -github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= -github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= -github.com/go-playground/validator/v10 v10.8.0/go.mod h1:9JhgTzTaE31GZDpH/HSvHiRJrJ3iKAgqqH0Bl/Ocjdk= -github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw= -github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= -github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/goccy/go-json v0.9.8 h1:DxXB6MLd6yyel7CLph8EwNIonUtVZd3Ue5iRcL4DQCE= -github.com/goccy/go-json v0.9.8/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= -github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/gogo/protobuf v1.3.0/go.mod 
h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= -github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= -github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.1.0 h1:XUgk2Ex5veyVFVeLm0xhusUTQybEbexJXrvPNOKkSY0= -github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= -github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= -github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/glog v0.0.0-20210429001901-424d2337a529/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= -github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf 
v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= -github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -github.com/google/certificate-transparency-go v1.0.21/go.mod h1:QeJfpSbVSfYc7RgB3gJFj9cbuQMMchQxrWXz8Ruopmg= -github.com/google/certificate-transparency-go v1.1.2-0.20210422104406-9f33727a7a18/go.mod h1:6CKh9dscIRoqc2kC6YUFICHZMT9NrClyPrRVFrdw1QQ= -github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 h1:806qveZBQtRNHroYHyg6yrsjqBJh9kIB4nfmB8uJnak= -github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92/go.mod h1:kXWPsHVPSKVuxPPG69BRtumCbAW537FydV/GH89oBhM= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM= -github.com/google/go-licenses v0.0.0-20210329231322-ce1d9163b77d/go.mod h1:+TYOmkVoJOpwnS0wfdsJCV9CoD5nJYsHoFk/0CrTK4M= -github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= -github.com/google/go-querystring v1.0.0/go.mod 
h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= -github.com/google/go-replayers/grpcreplay v0.1.0/go.mod h1:8Ig2Idjpr6gifRd6pNVggX6TC1Zw6Jx74AKp7QNH2QE= -github.com/google/go-replayers/httpreplay v0.1.0/go.mod h1:YKZViNhiGgqdBlUbI2MwGpq4pXxNmhJLPHQ7cv2b5no= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/licenseclassifier v0.0.0-20210325184830-bb04aff29e72/go.mod h1:qsqn2hxC+vURpyBRygGUuinTO42MFRLcsmQ/P8v94+M= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian v2.1.1-0.20190517191504-25dcb96d9e51+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0/go.mod h1:RaTPr0KUf2K7fnZYLNDrr8rxAamWs3iNywJLtQ2AzBg= -github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/trillian v1.3.14-0.20210409160123-c5ea3abd4a41/go.mod h1:1dPv0CUjNQVFEDuAUFhZql16pw/VlPgaX8qj+g5pVzQ= -github.com/google/trillian v1.3.14-0.20210428093031-b4ddea2e86b1/go.mod h1:FdIJX+NoDk/dIN2ZxTyz5nAJWgf+NSSSriPAMThChTY= -github.com/google/uuid v0.0.0-20161128191214-064e2069ce9c/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/wire v0.3.0/go.mod h1:i1DMg/Lu8Sz5yYl25iOdmc5CT5qusaa+zmRWs16741s= -github.com/googleapis/gax-go 
v2.0.2+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= -github.com/goreleaser/goreleaser v0.134.0/go.mod h1:ZT6Y2rSYa6NxQzIsdfWWNWAlYGXGbreo66NmE+3X3WQ= -github.com/goreleaser/nfpm v1.2.1/go.mod h1:TtWrABZozuLOttX2uDlYyECfQX7x5XYkVxhjYcR6G9w= -github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-middleware v1.2.2/go.mod h1:EaizFBKfUKtMIF5iaDEhniwNedqGo9FuLFzppDr3uwI= -github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= -github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.8.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-gateway v1.9.2/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-gateway v1.14.6/go.mod h1:zdiPV4Yse/1gnckTHtghG4GkDEdKCRJduHpTxT3/jcw= -github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= -github.com/hashicorp/consul/api v1.1.0/go.mod 
h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= -github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= -github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= -github.com/hashicorp/go-retryablehttp v0.6.4/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= -github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/go-version v1.3.0 h1:McDWVJIU/y+u1BRV06dPaLfLCaT7fUTJLp5r04x7iNw= -github.com/hashicorp/go-version v1.3.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= -github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= -github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= -github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= -github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4= -github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= -github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= -github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/imdario/mergo v0.3.4/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= -github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= -github.com/imdario/mergo v0.3.9/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= -github.com/inconshreveable/mousetrap v1.0.0 
h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= -github.com/jarcoal/httpmock v1.0.5/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= -github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= -github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/jhump/protoreflect v1.6.1/go.mod h1:RZQ/lnuN+zqeRVpQigTwO6o0AJUkxbnSnpuG7toUTG4= -github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= -github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= -github.com/jinzhu/gorm v1.9.11 h1:gaHGvE+UnWGlbWG4Y3FUwY1EcZ5n6S9WtqBA/uySMLE= -github.com/jinzhu/gorm v1.9.11/go.mod h1:bu/pK8szGZ2puuErfU0RwyeNdsf3e6nCX/noXaVxkfw= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= -github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= -github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= -github.com/jmhodges/clock v0.0.0-20160418191101-880ee4c33548/go.mod h1:hGT6jSUVzF6no3QaDSMLGLEHtHSBSefs+MgcDWnmhmo= -github.com/jmoiron/sqlx v1.3.3/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= -github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= -github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= -github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= -github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/jtolds/gls 
v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY= -github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= -github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= -github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= -github.com/kisom/goutils v1.4.3/go.mod h1:Lp5qrquG7yhYnWzZCI/68Pa/GpFynw//od6EkGnWpac= -github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= -github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= -github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= -github.com/klauspost/pgzip v1.2.5 h1:qnWYvvKqedOF2ulHpMG72XQol4ILEJ8k2wwRl/Km8oE= -github.com/klauspost/pgzip v1.2.5/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= -github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 
-github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= -github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= -github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= -github.com/lyft/protoc-gen-star v0.5.1/go.mod h1:9toiA3cC7z5uVbODF7kEQ91Xn7XNFkVUl+SrEe+ZORU= -github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxmAOow= -github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= -github.com/mattn/go-shellwords v1.0.10/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= -github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-zglob v0.0.1/go.mod h1:9fxibJccNxU2cnpIKLRRFA7zX7qhkJIQWBb449FYHOo= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod 
h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/mholt/archiver/v4 v4.0.0-alpha.6 h1:3wvos9Kn1GpKNBz+MpozinGREPslLo1ds1W16vTkErQ= -github.com/mholt/archiver/v4 v4.0.0-alpha.6/go.mod h1:9PTygYq90FQBWPspdwAng6dNjYiBuTYKqmA6c15KuCo= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/miekg/pkcs11 v1.0.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= -github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= -github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= -github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= -github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= -github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= -github.com/mitchellh/reflectwalk v1.0.1/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2 h1:daZqE/T/yEoKIQNd3rwNeLsiS0VpZFfJulR0t/rtgAE= -github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY= -github.com/mozillazg/go-httpheader v0.2.1 h1:geV7TrjbL8KXSyvghnFm+NyTux/hxwueTSrwhe88TQQ= -github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60= -github.com/mreiferson/go-httpclient v0.0.0-20160630210159-31f0106b4474/go.mod h1:OQA4XLvDbMgS8P0CevmM4m9Q3Jq4phKUzcocxuGJ5m8= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/mwitkow/go-proto-validators v0.0.0-20180403085117-0950a7990007/go.mod 
h1:m2XC9Qq0AlmmVksL6FktJCdTYyLk7V3fKyp0sl1yWQo= -github.com/mwitkow/go-proto-validators v0.2.0/go.mod h1:ZfA1hW+UH/2ZHOWvQ3HnQaU0DtnpXu850MZiy+YUgcc= -github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= -github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= -github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= -github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= -github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= -github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso= -github.com/nkovacs/streamquote v1.0.0/go.mod h1:BN+NaZ2CmdKqUuTUXUEm9j95B2TRbpOWpxbJYzzgUsc= -github.com/nwaples/rardecode/v2 v2.0.0-beta.2 h1:e3mzJFJs4k83GXBEiTaQ5HgSc/kOK8q0rDaRO0MPaOk= -github.com/nwaples/rardecode/v2 v2.0.0-beta.2/go.mod h1:yntwv/HfMc/Hbvtq9I19D1n58te3h6KsqCf3GxyfBGY= -github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= -github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/olekukonko/tablewriter v0.0.4/go.mod h1:zq6QwlOf5SlnkVbMSr5EoBv3636FWnp+qbPhuoO21uA= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= -github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= -github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= -github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= -github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= -github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= -github.com/otiai10/curr 
v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= -github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= -github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= -github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= -github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= -github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= -github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= -github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE= -github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= -github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= -github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= -github.com/pquerna/otp v1.2.0 h1:/A3+Jn+cagqayeR3iHs/L62m5ue7710D35zl1zJ1kok= -github.com/pquerna/otp v1.2.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= -github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= -github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= -github.com/prometheus/client_golang v1.5.1/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= -github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= -github.com/prometheus/client_golang v1.10.0 h1:/o0BDeWzLWXNZ+4q5gXltUvaMpJqckTa+jTNoB+z4cg= -github.com/prometheus/client_golang v1.10.0/go.mod h1:WJM3cc3yu7XKBKa/I8WeZm+V3eltZnBwfENSU7mdogU= -github.com/prometheus/client_model 
v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= -github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= -github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= -github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= -github.com/prometheus/common v0.18.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= -github.com/prometheus/common v0.24.0 h1:aIycr3wRFxPUq8XlLQlGQ9aNXV3dFi5y62pe/SB262k= -github.com/prometheus/common v0.24.0/go.mod h1:H6QK/N6XVT42whUeIdI3dp36w49c+/iMDk7UAI2qm7Q= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= -github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= -github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= -github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= -github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/pseudomuto/protoc-gen-doc v1.4.1/go.mod h1:exDTOVwqpp30eV/EDPFLZy3Pwr2sn6hBC1WIYH/UbIg= -github.com/pseudomuto/protokit v0.2.0/go.mod h1:2PdH30hxVHsup8KpBTOXTBeMVhJZVio3Q8ViKSAXT0Q= -github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk= -github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs= -github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w= -github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= -github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs= -github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY= -github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod 
h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= -github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= -github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= -github.com/rogpeppe/fastuuid v1.1.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= -github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= -github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= -github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= -github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sirupsen/logrus 
v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.0.0 h1:UVQPSSmc3qtTi+zPPkCXvZX9VvW/xT/NsRvKfwY81a8= -github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= -github.com/smartystreets/go-aws-auth v0.0.0-20180515143844-0c1422d1fdb9/go.mod h1:SnhjPscd9TpLiy1LpzGSKh3bXCfxxXuqd9xmQJy3slM= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4SauJk4cUOwJs= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= -github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= -github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= -github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/speps/go-hashids v2.0.0+incompatible h1:kSfxGfESueJKTx0mpER9Y/1XHl+FVQjtCqRyYcviFbw= -github.com/speps/go-hashids v2.0.0+incompatible/go.mod h1:P7hqPzMdnZOfyIk+xrlG1QaSMw+gCBdHKsBDnhpaZvc= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4= -github.com/spf13/afero v1.3.4/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= -github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= -github.com/spf13/cobra v1.1.3 h1:xghbfqPkxzxP3C/f3n5DdpAbdKLj4ZE4BWQI362l53M= -github.com/spf13/cobra v1.1.3/go.mod h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= -github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= -github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI= -github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/amqp 
v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= -github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= -github.com/stretchr/testify v0.0.0-20170130113145-4d4bfba8f1d1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 h1:hfhmMk7j4uDMRkfrrIOneMVXPBEhy3HSYiWX0gWoyhc= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393/go.mod h1:482ndbWuXqgStZNCqE88UoZeDveIt0juS7MY71Vangg= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393 h1:4IehmEtin8mvOO9pDA3Uj1/X9cWndyDkSsJC0AcRXv4= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393 h1:WAiJZ+YhH44DT95BlUKbcRAj1WtorJp7Lxe87v3x/F4= -github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393/go.mod h1:actV4GtZOO0fUxWese/7SNZ7l+LlepXuQyAEbC/0UMs= -github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac h1:PSBhZblOjdwH7SIVgcue+7OlnLHkM45KuScLZ+PiVbQ= -github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac/go.mod h1:wQBO5HdAkLjj2q6XQiIfDSP8DXDNrppDRw2Kp/1BODA= -github.com/therootcompany/xz v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw= -github.com/therootcompany/xz v1.0.1/go.mod h1:3K3UH1yCKgBneZYhuQUvJ9HPD19UEXEI0BWbMn8qNMY= -github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= -github.com/tj/go-elastic v0.0.0-20171221160941-36157cbbebc2/go.mod h1:WjeM0Oo1eNAjXGDx2yma7uG2XoyRZTq1uv3M/o7imD0= -github.com/tj/go-kinesis v0.0.0-20171128231115-08b17f58cb1b/go.mod h1:/yhzCV0xPfx6jb1bBgRFjl5lytqVqZXEaeqWP8lTEao= -github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKwh4= -github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 
h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= -github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4= -github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= -github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= -github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= -github.com/ulikunitz/xz v0.5.7/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= -github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/upyun/go-sdk v2.1.0+incompatible h1:OdjXghQ/TVetWV16Pz3C1/SUpjhGBVPr+cLiqZLLyq0= -github.com/upyun/go-sdk v2.1.0+incompatible/go.mod h1:eu3F5Uz4b9ZE5bE5QsCL6mgSNWRwfj0zpJ9J626HEqs= -github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= -github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU= -github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -github.com/weppos/publicsuffix-go v0.13.1-0.20210123135404-5fd73613514e/go.mod h1:HYux0V0Zi04bHNwOHy4cXJVz/TQjYonnF6aoYhj+3QE= -github.com/weppos/publicsuffix-go v0.15.1-0.20210511084619-b1f36a2d6c0b/go.mod h1:HYux0V0Zi04bHNwOHy4cXJVz/TQjYonnF6aoYhj+3QE= -github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= -github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/xanzy/go-gitlab v0.31.0/go.mod h1:sPLojNBn68fMUWSxIJtdVVIP8uSBYqesTfDUseX11Ug= -github.com/xanzy/ssh-agent v0.2.1/go.mod h1:mLlQY/MoOhWBj+gOGMQkOeiEvkx+8pJSI+0Bx9h2kr4= -github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/ziutek/mymysql v1.5.4/go.mod 
h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= -github.com/zmap/rc2 v0.0.0-20131011165748-24b9757f5521/go.mod h1:3YZ9o3WnatTIZhuOtot4IcUfzoKVjUHqu6WALIyI0nE= -github.com/zmap/zcertificate v0.0.0-20180516150559-0e3d58b1bac4/go.mod h1:5iU54tB79AMBcySS0R2XIyZBAVmeHranShAFELYx7is= -github.com/zmap/zcrypto v0.0.0-20210123152837-9cf5beac6d91/go.mod h1:R/deQh6+tSWlgI9tb4jNmXxn8nSCabl5ZQsBX9//I/E= -github.com/zmap/zcrypto v0.0.0-20210511125630-18f1e0152cfc/go.mod h1:FM4U1E3NzlNMRnSUTU3P1UdukWhYGifqEsjk9fn7BCk= -github.com/zmap/zlint/v3 v3.1.0/go.mod h1:L7t8s3sEKkb0A2BxGy1IWrxt1ZATa1R4QfJZaQOD3zU= -go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= -go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= -go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= -go.etcd.io/etcd/api/v3 v3.5.0-alpha.0 h1:+e5nrluATIy3GP53znpkHMFzPTHGYyzvJGFCbuI6ZLc= -go.etcd.io/etcd/api/v3 v3.5.0-alpha.0/go.mod h1:mPcW6aZJukV6Aa81LSKpBjQXTWlXB5r74ymPoSWa3Sw= -go.etcd.io/etcd/client/v2 v2.305.0-alpha.0 h1:jZepGpOeJATxsbMNBZczDS2jHdK/QVHM1iPe9jURJ8o= -go.etcd.io/etcd/client/v2 v2.305.0-alpha.0/go.mod h1:kdV+xzCJ3luEBSIeQyB/OEKkWKd8Zkux4sbDeANrosU= -go.etcd.io/etcd/client/v3 v3.5.0-alpha.0 h1:dr1EOILak2pu4Nf5XbRIOCNIBjcz6UmkQd7hHRXwxaM= -go.etcd.io/etcd/client/v3 v3.5.0-alpha.0/go.mod h1:wKt7jgDgf/OfKiYmCq5WFGxOFAkVMLxiiXgLDFhECr8= -go.etcd.io/etcd/etcdctl/v3 v3.5.0-alpha.0 h1:odMFuQQCg0UmPd7Cyw6TViRYv9ybGuXuki4CusDSzqA= -go.etcd.io/etcd/etcdctl/v3 v3.5.0-alpha.0/go.mod h1:YPwSaBciV5G6Gpt435AasAG3ROetZsKNUzibRa/++oo= -go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0 h1:3yLUEC0nFCxw/RArImOyRUI4OAFbg4PFpBbAhSNzKNY= -go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0/go.mod h1:tV31atvwzcybuqejDoY3oaNRTtlD2l/Ot78Pc9w7DMY= -go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0 h1:DvYJotxV9q1Lkn7pknzAbFO/CLtCVidCr2K9qRLJ8pA= -go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0/go.mod h1:FAwse6Zlm5v4tEWZaTjmNhe17Int4Oxbu7+2r0DiD3w= -go.etcd.io/etcd/server/v3 v3.5.0-alpha.0 h1:fYv7CmmdyuIu27UmKQjS9K/1GtcCa+XnPKqiKBbQkrk= -go.etcd.io/etcd/server/v3 v3.5.0-alpha.0/go.mod h1:tsKetYpt980ZTpzl/gb+UOJj9RkIyCb1u4wjzMg90BQ= -go.etcd.io/etcd/tests/v3 v3.5.0-alpha.0 h1:UcRoCA1FgXoc4CEM8J31fqEvI69uFIObY5ZDEFH7Znc= -go.etcd.io/etcd/tests/v3 v3.5.0-alpha.0/go.mod h1:HnrHxjyCuZ8YDt8PYVyQQ5d1ZQfzJVEtQWllr5Vp/30= -go.etcd.io/etcd/v3 v3.5.0-alpha.0 h1:ZuqKJkD2HrzFUj8IB+GLkTMKZ3+7mWx172vx6F1TukM= -go.etcd.io/etcd/v3 v3.5.0-alpha.0/go.mod h1:JZ79d3LV6NUfPjUxXrpiFAYcjhT+06qqw+i28snx8To= -go.opencensus.io v0.15.0/go.mod h1:UffZAU+4sDEINUGP/B7UfBBkq4fqLu9zXAX7ke6CHW0= -go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= -go.opencensus.io v0.23.0/go.mod 
h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= -go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= -go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -go.uber.org/zap v1.16.0 h1:uFRZXykJGK9lLY4HtgSw44DnIcAM+kRBP7x5m+NpAOM= -go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= -gocloud.dev v0.19.0/go.mod h1:SmKwiR8YwIMMJvQBKLsC3fHNyMwXLw3PMDO+VVteJMI= -golang.org/x/crypto v0.0.0-20180501155221-613d6eafa307/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191117063200-497ca9f6d64f/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201124201722-c8d3bf9c5392/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= 
-golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= -golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod 
h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= -golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= -golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190619014844-b5b0513f8c1b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191119073136-fc4aabc6c914/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= -golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net 
v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= -golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210413134643-5e61552d6c78/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c h1:SgVl/sCtkicsS7psKkje4H9YtjdEl3xsYh7N+5TDHqY= -golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190620070143-6f217b454f45/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191119060738-e882bf8e40c2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys 
v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210412220455-f1c623a9e750/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210511113859-b0526f3d8744/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
-golang.org/x/sys v0.0.0-20211020174200-9d6173849985/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= -golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools 
v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190729092621-ff9f1409240a/go.mod h1:jcCCGcm9btYwXyDqrUWc6MKQKKGJCWEQ3AfLSRIbEuI= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191010075000-0337d82405ff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191118222007-07fc4c7f2b98/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools 
v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200426102838-f3a5411a4c3b/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= -golang.org/x/tools v0.0.0-20201014170642-d1624618ad65/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= -golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI= -golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.5.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.6.0/go.mod h1:btoxGiFvQNVUZQ8W08zLtrVS08CNpINPEfxXxgJL1Q4= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.10.0/go.mod 
h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= -google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= -google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= -google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= -google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= -google.golang.org/api v0.45.0 h1:pqMffJFLBVUDIoYsHcqtxgQVTsmxMDpYLOc5MT4Jrww= -google.golang.org/api v0.45.0/go.mod h1:ISLIJCedJolbZvDfAk+Ctuq5hf+aJ33WgtUsfyFoLXA= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.2/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20170818010345-ee236bd376b0/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181107211654-5fc9ac540362/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto 
v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190508193815-b515fa19cec8/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= -google.golang.org/genproto v0.0.0-20190620144150-6af8c5fc6601/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto 
v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210331142528-b7513248f0ba/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= -google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= -google.golang.org/genproto v0.0.0-20210413151531-c14fb6ef47c3/go.mod h1:P3QM42oQyzQSnHPnZ/vqoCdDmzH28fzWByN9asMeM8A= -google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a h1:tzkHckzMzgPr8SC4taTC3AldLr4+oJivSoq1xf/nhsc= -google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a/go.mod h1:P3QM42oQyzQSnHPnZ/vqoCdDmzH28fzWByN9asMeM8A= -google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc 
v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= -google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.37.0 h1:uSZWeQJX5j11bIQ4AJoj+McDBo29cY1MCoC1wO3ts+c= -google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= -gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= -gopkg.in/cheggaaa/pb.v1 v1.0.28 h1:n1tBJnnK2r7g9OW2btFH91V92STTUevLXYFb8gy9EMk= -gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= -gopkg.in/go-playground/assert.v1 
v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= -gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= -gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= -gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= -gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/mail.v2 v2.3.1 h1:WYFn/oANrAGP2C0dcV6/pbkPzv8yGzqTjPmTeO7qoXk= -gopkg.in/mail.v2 v2.3.1/go.mod h1:htwXN1Qh09vZJ1NVKxQqHPBaCBbzKhp5GzuJEA4VJWw= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/src-d/go-billy.v4 v4.3.2/go.mod h1:nDjArDMp+XMs1aFAESLRjfGSgfvoYN0hDfzEk0GjC98= -gopkg.in/src-d/go-git-fixtures.v3 v3.5.0/go.mod h1:dLBcvytrw/TYZsNTWCnkNF2DSIlzWYqTe3rJR56Ac7g= -gopkg.in/src-d/go-git.v4 v4.13.1/go.mod h1:nx5NYcxdKxq5fpltdHnPa2Exj4Sx0EclMWZQbYDu2z8= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.1.4 h1:SadWOkti5uVN1FAMgxn165+Mw00fuQKyk4Gyn/inxNQ= -honnef.co/go/tools v0.1.4/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= -modernc.org/libc v1.22.2 h1:4U7v51GyhlWqQmwCHj28Rdq2Yzwk55ovjFrdPjs8Hb0= -modernc.org/libc v1.22.2/go.mod h1:uvQavJ1pZ0hIoC/jfqNoMLURIMhKzINIWypNM17puug= -modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= -modernc.org/mathutil 
v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= -modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= -modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= -modernc.org/sqlite v1.20.3 h1:SqGJMMxjj1PHusLxdYxeQSodg7Jxn9WWkaAQjKrntZs= -modernc.org/sqlite v1.20.3/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A= -pack.ag/amqp v0.11.2/go.mod h1:4/cbmt4EJXSKlG6LCfWHoqmN0uFdy5i/+YFz+fTfhV4= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= -sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= -sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= -sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= -sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= diff --git a/inventory/client.go b/inventory/client.go new file mode 100644 index 00000000..d93cf52b --- /dev/null +++ b/inventory/client.go @@ -0,0 +1,454 @@ +package inventory + +import ( + "context" + rawsql "database/sql" + "database/sql/driver" + "fmt" + "os" + "time" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/ent/node" + _ "github.com/cloudreve/Cloudreve/v4/ent/runtime" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/debug" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "modernc.org/sqlite" +) + +const ( + DBVersionPrefix = "db_version_" + EnvDefaultOverwritePrefix = "CR_SETTING_DEFAULT_" + EnvEnableAria2 = "CR_ENABLE_ARIA2" +) + +// InitializeDBClient runs migration and returns a new ent.Client with additional configurations +// for hooks and interceptors. +func InitializeDBClient(l logging.Logger, + client *ent.Client, kv cache.Driver, requiredDbVersion string) (*ent.Client, error) { + ctx := context.WithValue(context.Background(), logging.LoggerCtx{}, l) + if needMigration(client, ctx, requiredDbVersion) { + // Run the auto migration tool. + if err := migrate(l, client, ctx, kv, requiredDbVersion); err != nil { + return nil, fmt.Errorf("failed to migrate database: %w", err) + } + } else { + l.Info("Database schema is up to date.") + } + + //createMockData(client, ctx) + return client, nil +} + +// NewRawEntClient returns a new ent.Client without additional configurations. 
+func NewRawEntClient(l logging.Logger, config conf.ConfigProvider) (*ent.Client, error) { + l.Info("Initializing database connection...") + dbConfig := config.Database() + confDBType := dbConfig.Type + if confDBType == conf.SQLite3DB || confDBType == "" { + confDBType = conf.SQLiteDB + } + + var ( + err error + client *sql.Driver + ) + + switch confDBType { + case conf.SQLiteDB: + dbFile := util.RelativePath(dbConfig.DBFile) + l.Info("Connect to SQLite database %q.", dbFile) + client, err = sql.Open("sqlite3", util.RelativePath(dbConfig.DBFile)) + case conf.PostgresDB: + l.Info("Connect to Postgres database %q.", dbConfig.Host) + client, err = sql.Open("postgres", fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable", + dbConfig.Host, + dbConfig.User, + dbConfig.Password, + dbConfig.Name, + dbConfig.Port)) + case conf.MySqlDB, conf.MsSqlDB: + l.Info("Connect to MySQL/SQLServer database %q.", dbConfig.Host) + var host string + if dbConfig.UnixSocket { + host = fmt.Sprintf("unix(%s)", + dbConfig.Host) + } else { + host = fmt.Sprintf("(%s:%d)", + dbConfig.Host, + dbConfig.Port) + } + + client, err = sql.Open(string(confDBType), fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local", + dbConfig.User, + dbConfig.Password, + host, + dbConfig.Name, + dbConfig.Charset)) + default: + return nil, fmt.Errorf("unsupported database type %q", confDBType) + } + + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + // Set connection pool + db := client.DB() + db.SetMaxIdleConns(50) + if confDBType == "sqlite" || confDBType == "UNSET" { + db.SetMaxOpenConns(1) + } else { + db.SetMaxOpenConns(100) + } + + // Set timeout + db.SetConnMaxLifetime(time.Second * 30) + + driverOpt := ent.Driver(client) + + // Enable verbose logging for debug mode. + if config.System().Debug { + l.Debug("Debug mode is enabled for DB client.") + driverOpt = ent.Driver(debug.DebugWithContext(client, func(ctx context.Context, i ...any) { + logging.FromContext(ctx).Debug(i[0].(string), i[1:]...) + })) + } + + return ent.NewClient(driverOpt), nil +} + +type sqlite3Driver struct { + *sqlite.Driver +} + +type sqlite3DriverConn interface { + Exec(string, []driver.Value) (driver.Result, error) +} + +func (d sqlite3Driver) Open(name string) (conn driver.Conn, err error) { + conn, err = d.Driver.Open(name) + if err != nil { + return + } + _, err = conn.(sqlite3DriverConn).Exec("PRAGMA foreign_keys = ON;", nil) + if err != nil { + _ = conn.Close() + } + return +} + +func init() { + rawsql.Register("sqlite3", sqlite3Driver{Driver: &sqlite.Driver{}}) +} + +// needMigration exams if required schema version is satisfied. 
+func needMigration(client *ent.Client, ctx context.Context, requiredDbVersion string) bool {
+	c, _ := client.Setting.Query().Where(setting.NameEQ(DBVersionPrefix + requiredDbVersion)).Count(ctx)
+	return c == 0
+}
+
+func migrate(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver, requiredDbVersion string) error {
+	l.Info("Start initializing database schema...")
+	l.Info("Creating basic table schema...")
+	if err := client.Schema.Create(ctx); err != nil {
+		return fmt.Errorf("failed creating schema resources: %w", err)
+	}
+
+	migrateDefaultSettings(l, client, ctx, kv)
+
+	if err := migrateDefaultStoragePolicy(l, client, ctx); err != nil {
+		return fmt.Errorf("failed migrating default storage policy: %w", err)
+	}
+
+	if err := migrateSysGroups(l, client, ctx); err != nil {
+		return fmt.Errorf("failed migrating system groups: %w", err)
+	}
+
+	if _, err := client.Setting.Create().SetName(DBVersionPrefix + requiredDbVersion).SetValue("installed").Save(ctx); err != nil {
+		return fmt.Errorf("failed to record schema version: %w", err)
+	}
+	return nil
+}
+
+func migrateDefaultSettings(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver) {
+	// clean kv cache
+	if err := kv.DeleteAll(); err != nil {
+		l.Warning("Failed to remove all KV entries during schema migration: %s", err)
+	}
+
+	// List existing settings into a map
+	existingSettings := make(map[string]struct{})
+	settings, err := client.Setting.Query().All(ctx)
+	if err != nil {
+		l.Warning("Failed to query existing settings: %s", err)
+	}
+
+	for _, s := range settings {
+		existingSettings[s.Name] = struct{}{}
+	}
+
+	l.Info("Insert default settings...")
+	for k, v := range DefaultSettings {
+		if _, ok := existingSettings[k]; ok {
+			l.Debug("Skip inserting setting %s, already exists.", k)
+			continue
+		}
+
+		if override, ok := os.LookupEnv(EnvDefaultOverwritePrefix + k); ok {
+			l.Info("Override default setting %q with env value %q", k, override)
+			v = override
+		}
+
+		client.Setting.Create().SetName(k).SetValue(v).SaveX(ctx)
+	}
+}
+
+func migrateDefaultStoragePolicy(l logging.Logger, client *ent.Client, ctx context.Context) error {
+	if _, err := client.StoragePolicy.Query().Where(storagepolicy.ID(1)).First(ctx); err == nil {
+		l.Info("Default storage policy (ID=1) already exists, skip migrating.")
+		return nil
+	}
+
+	l.Info("Insert default storage policy...")
+	if _, err := client.StoragePolicy.Create().
+		SetName("Default storage policy").
+		SetType(types.PolicyTypeLocal).
+		SetDirNameRule(util.DataPath("uploads/{uid}/{path}")).
+		SetFileNameRule("{uid}_{randomkey8}_{originname}").
+		SetSettings(&types.PolicySetting{
+			ChunkSize:   25 << 20, // 25MB
+			PreAllocate: true,
+		}).
+ Save(ctx); err != nil { + return fmt.Errorf("failed to create default storage policy: %w", err) + } + + return nil +} + +func migrateSysGroups(l logging.Logger, client *ent.Client, ctx context.Context) error { + if err := migrateAdminGroup(l, client, ctx); err != nil { + return err + } + + if err := migrateUserGroup(l, client, ctx); err != nil { + return err + } + + if err := migrateAnonymousGroup(l, client, ctx); err != nil { + return err + } + + if err := migrateMasterNode(l, client, ctx); err != nil { + return err + } + + return nil +} + +func migrateAdminGroup(l logging.Logger, client *ent.Client, ctx context.Context) error { + if _, err := client.Group.Query().Where(group.ID(1)).First(ctx); err == nil { + l.Info("Default admin group (ID=1) already exists, skip migrating.") + return nil + } + + l.Info("Insert default admin group...") + permissions := &boolset.BooleanSet{} + boolset.Sets(map[types.GroupPermission]bool{ + types.GroupPermissionIsAdmin: true, + types.GroupPermissionShare: true, + types.GroupPermissionWebDAV: true, + types.GroupPermissionWebDAVProxy: true, + types.GroupPermissionArchiveDownload: true, + types.GroupPermissionArchiveTask: true, + types.GroupPermissionShareDownload: true, + types.GroupPermissionRemoteDownload: true, + types.GroupPermissionRedirectedSource: true, + types.GroupPermissionAdvanceDelete: true, + types.GroupPermissionIgnoreFileOwnership: true, + // TODO: review default permission + }, permissions) + if _, err := client.Group.Create(). + SetName("Admin"). + SetStoragePoliciesID(1). + SetMaxStorage(1 * constants.TB). // 1 TB default storage + SetPermissions(permissions). + SetSettings(&types.GroupSetting{ + SourceBatchSize: 1000, + Aria2BatchSize: 50, + MaxWalkedFiles: 100000, + TrashRetention: 7 * 24 * 3600, + RedirectedSource: true, + }). + Save(ctx); err != nil { + return fmt.Errorf("failed to create default admin group: %w", err) + } + + return nil +} + +func migrateUserGroup(l logging.Logger, client *ent.Client, ctx context.Context) error { + if _, err := client.Group.Query().Where(group.ID(2)).First(ctx); err == nil { + l.Info("Default user group (ID=2) already exists, skip migrating.") + return nil + } + + l.Info("Insert default user group...") + permissions := &boolset.BooleanSet{} + boolset.Sets(map[types.GroupPermission]bool{ + types.GroupPermissionShare: true, + types.GroupPermissionShareDownload: true, + types.GroupPermissionRedirectedSource: true, + }, permissions) + if _, err := client.Group.Create(). + SetName("User"). + SetStoragePoliciesID(1). + SetMaxStorage(1 * constants.GB). // 1 GB default storage + SetPermissions(permissions). + SetSettings(&types.GroupSetting{ + SourceBatchSize: 10, + Aria2BatchSize: 1, + MaxWalkedFiles: 100000, + TrashRetention: 7 * 24 * 3600, + RedirectedSource: true, + }). + Save(ctx); err != nil { + return fmt.Errorf("failed to create default user group: %w", err) + } + + return nil +} + +func migrateAnonymousGroup(l logging.Logger, client *ent.Client, ctx context.Context) error { + if _, err := client.Group.Query().Where(group.ID(AnonymousGroupID)).First(ctx); err == nil { + l.Info("Default anonymous group (ID=3) already exists, skip migrating.") + return nil + } + + l.Info("Insert default anonymous group...") + permissions := &boolset.BooleanSet{} + boolset.Sets(map[types.GroupPermission]bool{ + types.GroupPermissionIsAnonymous: true, + types.GroupPermissionShareDownload: true, + }, permissions) + if _, err := client.Group.Create(). + SetName("Anonymous"). + SetPermissions(permissions). 
+ SetSettings(&types.GroupSetting{ + MaxWalkedFiles: 100000, + RedirectedSource: true, + }). + Save(ctx); err != nil { + return fmt.Errorf("failed to create default anonymous group: %w", err) + } + + return nil +} + +func migrateMasterNode(l logging.Logger, client *ent.Client, ctx context.Context) error { + if _, err := client.Node.Query().Where(node.TypeEQ(node.TypeMaster)).First(ctx); err == nil { + l.Info("Default master node already exists, skip migrating.") + return nil + } + + capabilities := &boolset.BooleanSet{} + boolset.Sets(map[types.NodeCapability]bool{ + types.NodeCapabilityCreateArchive: true, + types.NodeCapabilityExtractArchive: true, + types.NodeCapabilityRemoteDownload: true, + }, capabilities) + + stm := client.Node.Create(). + SetType(node.TypeMaster). + SetCapabilities(capabilities). + SetName("Master"). + SetSettings(&types.NodeSetting{ + Provider: types.DownloaderProviderAria2, + }). + SetStatus(node.StatusActive) + + _, enableAria2 := os.LookupEnv(EnvEnableAria2) + if enableAria2 { + l.Info("Aria2 is override as enabled.") + stm.SetSettings(&types.NodeSetting{ + Provider: types.DownloaderProviderAria2, + Aria2Setting: &types.Aria2Setting{ + Server: "http://127.0.0.1:6800/jsonrpc", + }, + }) + } + + l.Info("Insert default master node...") + if _, err := stm.Save(ctx); err != nil { + return fmt.Errorf("failed to create default master node: %w", err) + } + + return nil +} + +func createMockData(client *ent.Client, ctx context.Context) { + //userCount := 100 + //folderCount := 10000 + //fileCount := 25000 + // + //// create users + //pwdDigest, _ := digestPassword("52121225") + //userCreates := make([]*ent.UserCreate, userCount) + //for i := 0; i < userCount; i++ { + // nick := uuid.Must(uuid.NewV4()).String() + // userCreates[i] = client.User.Create(). + // SetEmail(nick + "@cloudreve.org"). + // SetNick(nick). + // SetPassword(pwdDigest). + // SetStatus(user.StatusActive). + // SetGroupID(1) + //} + //users, err := client.User.CreateBulk(userCreates...).Save(ctx) + //if err != nil { + // panic(err) + //} + // + //// Create root folder + //rootFolderCreates := make([]*ent.FileCreate, userCount) + //folderIds := make([][]int, 0, folderCount*userCount+userCount) + //for i, user := range users { + // rootFolderCreates[i] = client.File.Create(). + // SetName(RootFolderName). + // SetOwnerID(user.ID). + // SetType(int(FileTypeFolder)) + //} + //rootFolders, err := client.File.CreateBulk(rootFolderCreates...).Save(ctx) + //for _, rootFolders := range rootFolders { + // folderIds = append(folderIds, []int{rootFolders.ID, rootFolders.OwnerID}) + //} + //if err != nil { + // panic(err) + //} + // + //// create random folder + //for i := 0; i < folderCount*userCount; i++ { + // parent := lo.Sample(folderIds) + // res := client.File.Create(). + // SetName(uuid.Must(uuid.NewV4()).String()). + // SetType(int(FileTypeFolder)). + // SetOwnerID(parent[1]). + // SetFileChildren(parent[0]). 
+ // SaveX(ctx) + // folderIds = append(folderIds, []int{res.ID, res.OwnerID}) + //} + + for i := 0; i < 255; i++ { + fmt.Printf("%d/", i) + } +} diff --git a/inventory/common.go b/inventory/common.go new file mode 100644 index 00000000..5db42fed --- /dev/null +++ b/inventory/common.go @@ -0,0 +1,124 @@ +package inventory + +import ( + "encoding/base64" + "encoding/json" + "entgo.io/ent/dialect/sql" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "time" +) + +type ( + OrderDirection string + PaginationArgs struct { + UseCursorPagination bool + Page int + PageToken string + PageSize int + OrderBy string + Order OrderDirection + } + + PaginationResults struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int `json:"total_items,omitempty"` + NextPageToken string `json:"next_token,omitempty"` + IsCursor bool `json:"is_cursor,omitempty"` + } + + PageToken struct { + Time *time.Time `json:"time,omitempty"` + ID int `json:"-"` + IDHash string `json:"id,omitempty"` + String string `json:"string,omitempty"` + Int int `json:"int,omitempty"` + StartWithFile bool `json:"start_with_file,omitempty"` + } +) + +const ( + OrderDirectionAsc = OrderDirection("asc") + OrderDirectionDesc = OrderDirection("desc") +) + +var ( + ErrTooManyArguments = fmt.Errorf("too many arguments") +) + +func pageTokenFromString(s string, hasher hashid.Encoder, idType int) (*PageToken, error) { + sB64Decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 for page token: %w", err) + } + + token := &PageToken{} + if err := json.Unmarshal(sB64Decoded, token); err != nil { + return nil, fmt.Errorf("failed to unmarshal page token: %w", err) + } + + id, err := hasher.Decode(token.IDHash, idType) + if err != nil { + return nil, fmt.Errorf("failed to decode id: %w", err) + } + + if token.Time == nil { + token.Time = &time.Time{} + } + + token.ID = id + return token, nil +} + +func (p *PageToken) Encode(hasher hashid.Encoder, encodeFunc hashid.EncodeFunc) (string, error) { + p.IDHash = encodeFunc(hasher, p.ID) + res, err := json.Marshal(p) + if err != nil { + return "", fmt.Errorf("failed to marshal page token: %w", err) + } + + return base64.StdEncoding.EncodeToString(res), nil +} + +// sqlParamLimit returns the max number of sql parameters. +func sqlParamLimit(dbType conf.DBType) int { + switch dbType { + case conf.PostgresDB: + return 34464 + case conf.SQLiteDB, conf.SQLite3DB: + // https://www.sqlite.org/limits.html + return 32766 + default: + return 32766 + } +} + +// getOrderTerm returns the order term for ent. 
+func getOrderTerm(d OrderDirection) sql.OrderTermOption { + switch d { + case OrderDirectionDesc: + return sql.OrderDesc() + default: + return sql.OrderAsc() + } +} + +func capPageSize(maxSQlParam, preferredSize, margin int) int { + // Page size should not be bigger than max SQL parameter + pageSize := preferredSize + if maxSQlParam > 0 && pageSize > maxSQlParam-margin || pageSize == 0 { + pageSize = maxSQlParam - margin + } + + return pageSize +} + +type StorageDiff map[int]int64 + +func (s *StorageDiff) Merge(diff StorageDiff) { + for k, v := range diff { + (*s)[k] += v + } +} diff --git a/inventory/dav_account.go b/inventory/dav_account.go new file mode 100644 index 00000000..d2f7e3cc --- /dev/null +++ b/inventory/dav_account.go @@ -0,0 +1,187 @@ +package inventory + +import ( + "context" + "entgo.io/ent/dialect/sql" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/samber/lo" +) + +type ( + DavAccountClient interface { + TxOperator + // List returns a list of dav accounts with the given args. + List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error) + // Create creates a new dav account. + Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error) + // Update updates a dav account. + Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error) + // GetByIDAndUserID returns the dav account with given id and user id. + GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error) + // Delete deletes the dav account. + Delete(ctx context.Context, id int) error + } + + ListDavAccountArgs struct { + *PaginationArgs + UserID int + } + ListDavAccountResult struct { + *PaginationResults + Accounts []*ent.DavAccount + } + CreateDavAccountParams struct { + UserID int + Name string + URI string + Password string + Options *boolset.BooleanSet + } +) + +func NewDavAccountClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DavAccountClient { + return &davAccountClient{ + client: client, + hasher: hasher, + maxSQlParam: sqlParamLimit(dbType), + } +} + +type davAccountClient struct { + maxSQlParam int + client *ent.Client + hasher hashid.Encoder +} + +func (c *davAccountClient) SetClient(newClient *ent.Client) TxOperator { + return &davAccountClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam} +} + +func (c *davAccountClient) GetClient() *ent.Client { + return c.client +} + +func (c *davAccountClient) Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error) { + account := c.client.DavAccount.Create(). + SetOwnerID(params.UserID). + SetName(params.Name). + SetURI(params.URI). + SetPassword(params.Password). + SetOptions(params.Options) + + return account.Save(ctx) +} + +func (c *davAccountClient) GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error) { + return c.client.DavAccount.Query(). + Where(davaccount.ID(id), davaccount.OwnerID(userID)). + First(ctx) +} + +func (c *davAccountClient) Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error) { + account := c.client.DavAccount.UpdateOneID(id). + SetName(params.Name). + SetURI(params.URI). 
+ SetOptions(params.Options) + + return account.Save(ctx) +} + +func (c *davAccountClient) Delete(ctx context.Context, id int) error { + return c.client.DavAccount.DeleteOneID(id).Exec(ctx) +} + +func (c *davAccountClient) List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error) { + query := c.listQuery(args) + + var ( + accounts []*ent.DavAccount + err error + paginationRes *PaginationResults + ) + accounts, paginationRes, err = c.cursorPagination(ctx, query, args, 10) + + if err != nil { + return nil, fmt.Errorf("query failed with paginiation: %w", err) + } + + return &ListDavAccountResult{ + Accounts: accounts, + PaginationResults: paginationRes, + }, nil +} + +func (c *davAccountClient) cursorPagination(ctx context.Context, query *ent.DavAccountQuery, args *ListDavAccountArgs, paramMargin int) ([]*ent.DavAccount, *PaginationResults, error) { + pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin) + query.Order(davaccount.ByID(sql.OrderDesc())) + + var ( + pageToken *PageToken + err error + ) + if args.PageToken != "" { + pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.DavAccountID) + if err != nil { + return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err) + } + } + queryPaged := getDavAccountCursorQuery(pageToken, query) + + // Use page size + 1 to determine if there are more items to come + queryPaged.Limit(pageSize + 1) + + logs, err := queryPaged. + All(ctx) + if err != nil { + return nil, nil, err + } + + // More items to come + nextTokenStr := "" + if len(logs) > pageSize { + lastItem := logs[len(logs)-2] + nextToken, err := getDavAccountNextPageToken(c.hasher, lastItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate next page token: %w", err) + } + + nextTokenStr = nextToken + } + + return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{ + PageSize: pageSize, + NextPageToken: nextTokenStr, + IsCursor: true, + }, nil +} + +func (c *davAccountClient) listQuery(args *ListDavAccountArgs) *ent.DavAccountQuery { + query := c.client.DavAccount.Query() + if args.UserID > 0 { + query.Where(davaccount.OwnerID(args.UserID)) + } + + return query +} + +// getDavAccountNextPageToken returns the next page token for the given last dav account. +func getDavAccountNextPageToken(hasher hashid.Encoder, last *ent.DavAccount) (string, error) { + token := &PageToken{ + ID: last.ID, + } + + return token.Encode(hasher, hashid.EncodeDavAccountID) +} + +func getDavAccountCursorQuery(token *PageToken, query *ent.DavAccountQuery) *ent.DavAccountQuery { + if token != nil { + query.Where(davaccount.IDLT(token.ID)) + } + + return query +} diff --git a/inventory/debug/debug.go b/inventory/debug/debug.go new file mode 100644 index 00000000..e327f6d8 --- /dev/null +++ b/inventory/debug/debug.go @@ -0,0 +1,188 @@ +package debug + +import ( + "context" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "fmt" + "github.com/google/uuid" + "time" +) + +const strMaxLen = 102400 + +type SkipDbLogging struct{} + +// DebugDriver is a driver that logs all driver operations. +type DebugDriver struct { + dialect.Driver // underlying driver. + log func(context.Context, ...any) // log function. defaults to log.Println. +} + +// DebugWithContext gets a driver and a logging function, and returns +// a new debugged-driver that prints all outgoing operations with context. 
+func DebugWithContext(d dialect.Driver, logger func(context.Context, ...any)) dialect.Driver { + drv := &DebugDriver{d, logger} + return drv +} + +// Exec logs its params and calls the underlying driver Exec method. +func (d *DebugDriver) Exec(ctx context.Context, query string, args, v any) error { + start := time.Now() + err := d.Driver.Exec(ctx, query, args, v) + if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip { + return err + } + + d.log(ctx, fmt.Sprintf("driver.Exec: query=%v args=%v time=%v", query, args, time.Since(start))) + return err +} + +// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported. +func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + drv, ok := d.Driver.(interface { + ExecContext(context.Context, string, ...any) (sql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.ExecContext is not supported") + } + if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip { + return drv.ExecContext(ctx, query, args...) + } + d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args)) + return drv.ExecContext(ctx, query, args...) +} + +// Query logs its params and calls the underlying driver Query method. +func (d *DebugDriver) Query(ctx context.Context, query string, args, v any) error { + start := time.Now() + err := d.Driver.Query(ctx, query, args, v) + if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip { + return err + } + d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v time=%v", query, args, time.Since(start))) + return err +} + +// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported. +func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + drv, ok := d.Driver.(interface { + QueryContext(context.Context, string, ...any) (*sql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.QueryContext is not supported") + } + if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip { + return drv.QueryContext(ctx, query, args...) + } + d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args)) + return drv.QueryContext(ctx, query, args...) +} + +// Tx adds an log-id for the transaction and calls the underlying driver Tx command. +func (d *DebugDriver) Tx(ctx context.Context) (dialect.Tx, error) { + tx, err := d.Driver.Tx(ctx) + if err != nil { + return nil, err + } + id := uuid.New().String() + d.log(ctx, fmt.Sprintf("driver.Tx(%s): started", id)) + return &DebugTx{tx, id, d.log, ctx}, nil +} + +// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported. +func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) { + drv, ok := d.Driver.(interface { + BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.BeginTx is not supported") + } + tx, err := drv.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + id := uuid.New().String() + d.log(ctx, fmt.Sprintf("driver.BeginTx(%s): started", id)) + return &DebugTx{tx, id, d.log, ctx}, nil +} + +// DebugTx is a transaction implementation that logs all transaction operations. +type DebugTx struct { + dialect.Tx // underlying transaction. + id string // transaction logging id. + log func(context.Context, ...any) // log function. defaults to fmt.Println. + ctx context.Context // underlying transaction context. 
+}
+
+// Exec logs its params and calls the underlying transaction Exec method.
+func (d *DebugTx) Exec(ctx context.Context, query string, args, v any) error {
+	start := time.Now()
+	err := d.Tx.Exec(ctx, query, args, v)
+	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
+		return err
+	}
+	// Truncate oversized string arguments in a copy, so the caller's slice is left untouched.
+	printArgs := args
+	if argsArray, ok := args.([]interface{}); ok {
+		truncated := make([]interface{}, len(argsArray))
+		copy(truncated, argsArray)
+		for i, argVal := range truncated {
+			if argValStr, ok := argVal.(string); ok && len(argValStr) > strMaxLen {
+				truncated[i] = argValStr[:strMaxLen] + "...[Truncated]..."
+			}
+		}
+		printArgs = truncated
+	}
+	d.log(ctx, fmt.Sprintf("Tx(%s).Exec: query=%v args=%v time=%v", d.id, query, printArgs, time.Since(start)))
+	return err
+}
+
+// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported.
+func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	drv, ok := d.Tx.(interface {
+		ExecContext(context.Context, string, ...any) (sql.Result, error)
+	})
+	if !ok {
+		return nil, fmt.Errorf("Tx.ExecContext is not supported")
+	}
+	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
+		return drv.ExecContext(ctx, query, args...)
+	}
+	d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args))
+	return drv.ExecContext(ctx, query, args...)
+}
+
+// Query logs its params and calls the underlying transaction Query method.
+func (d *DebugTx) Query(ctx context.Context, query string, args, v any) error {
+	start := time.Now()
+	err := d.Tx.Query(ctx, query, args, v)
+	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
+		return err
+	}
+	d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v time=%v", d.id, query, args, time.Since(start)))
+	return err
+}
+
+// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported.
+func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
+	drv, ok := d.Tx.(interface {
+		QueryContext(context.Context, string, ...any) (*sql.Rows, error)
+	})
+	if !ok {
+		return nil, fmt.Errorf("Tx.QueryContext is not supported")
+	}
+	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
+		return drv.QueryContext(ctx, query, args...)
+	}
+	d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args))
+	return drv.QueryContext(ctx, query, args...)
+}
+
+// Commit logs this step and calls the underlying transaction Commit method.
+func (d *DebugTx) Commit() error {
+	d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id))
+	return d.Tx.Commit()
+}
+
+// Rollback logs this step and calls the underlying transaction Rollback method.
+func (d *DebugTx) Rollback() error { + d.log(d.ctx, fmt.Sprintf("Tx(%s): rollbacked", d.id)) + return d.Tx.Rollback() +} diff --git a/inventory/direct_link.go b/inventory/direct_link.go new file mode 100644 index 00000000..dfed618b --- /dev/null +++ b/inventory/direct_link.go @@ -0,0 +1,70 @@ +package inventory + +import ( + "context" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" +) + +type ( + DirectLinkClient interface { + TxOperator + // GetByNameID get direct link by name and id + GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error) + // GetByID get direct link by id + GetByID(ctx context.Context, id int) (*ent.DirectLink, error) + } + LoadDirectLinkFile struct{} +) + +func NewDirectLinkClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DirectLinkClient { + return &directLinkClient{ + client: client, + hasher: hasher, + maxSQlParam: sqlParamLimit(dbType), + } +} + +type directLinkClient struct { + maxSQlParam int + client *ent.Client + hasher hashid.Encoder +} + +func (c *directLinkClient) SetClient(newClient *ent.Client) TxOperator { + return &directLinkClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam} +} + +func (c *directLinkClient) GetClient() *ent.Client { + return c.client +} + +func (d *directLinkClient) GetByID(ctx context.Context, id int) (*ent.DirectLink, error) { + return withDirectLinkEagerLoading(ctx, d.client.DirectLink.Query().Where(directlink.ID(id))). + First(ctx) +} + +func (d *directLinkClient) GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error) { + res, err := withDirectLinkEagerLoading(ctx, d.client.DirectLink.Query().Where(directlink.ID(id), directlink.Name(name))). + First(ctx) + if err != nil { + return nil, err + } + + // Increase download counter + _, _ = d.client.DirectLink.Update().Where(directlink.ID(res.ID)).SetDownloads(res.Downloads + 1).Save(ctx) + + return res, nil +} + +func withDirectLinkEagerLoading(ctx context.Context, q *ent.DirectLinkQuery) *ent.DirectLinkQuery { + if v, ok := ctx.Value(LoadDirectLinkFile{}).(bool); ok && v { + q.WithFile(func(m *ent.FileQuery) { + withFileEagerLoading(ctx, m) + }) + } + return q +} diff --git a/inventory/file.go b/inventory/file.go new file mode 100644 index 00000000..34b72718 --- /dev/null +++ b/inventory/file.go @@ -0,0 +1,1113 @@ +package inventory + +import ( + "context" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/directlink" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/schema" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/gofrs/uuid" + "github.com/samber/lo" + "golang.org/x/tools/container/intsets" +) + +const ( + RootFolderName = "" + SearchWildcard = "*" + MaxMetadataLen = 65535 +) + +type ( + // Ctx keys for eager loading options. 
+ LoadFileEntity struct{} + LoadFileMetadata struct{} + LoadFilePublicMetadata struct{} + LoadFileShare struct{} + LoadFileUser struct{} + LoadFileDirectLink struct{} + LoadEntityUser struct{} + LoadEntityStoragePolicy struct{} + LoadEntityFile struct{} + + // Parameters for file list + ListFileParameters struct { + *PaginationArgs + // Whether to mix folder with files in results, only applied to cursor pagination + MixedType bool + // Whether to include only folder in results, only applied to cursor pagination + FolderOnly bool + // SharedWithMe indicates whether to list files shared with the user + SharedWithMe bool + Search *SearchFileParameters + } + + FlattenListFileParameters struct { + *PaginationArgs + UserID int + Name string + StoragePolicyID int + } + + SearchFileParameters struct { + Name []string + // NameOperatorOr is true if the name should match any of the given names, false if all of them + NameOperatorOr bool + Metadata map[string]string + Type *types.FileType + UseFullText bool + CaseFolding bool + Category string + SizeGte int64 + SizeLte int64 + CreatedAtGte *time.Time + CreatedAtLte *time.Time + UpdatedAtGte *time.Time + UpdatedAtLte *time.Time + } + + ListEntityParameters struct { + *PaginationArgs + EntityType *types.EntityType + UserID int + StoragePolicyID int + } + + ListEntityResult struct { + Entities []*ent.Entity + *PaginationResults + } + + ListFileResult struct { + Files []*ent.File + MixedType bool + *PaginationResults + } + + CreateFileParameters struct { + FileType types.FileType + Name string + Metadata map[string]string + MetadataPrivateMask map[string]bool + IsSymbolic bool + StoragePolicyID int + *EntityParameters + } + + CreateFolderParameters struct { + Owner int + Name string + IsSymbolic bool + Metadata map[string]string + MetadataPrivateMask map[string]bool + } + + EntityParameters struct { + ModifiedAt *time.Time + OwnerID int + EntityType types.EntityType + StoragePolicyID int + Source string + Size int64 + UploadSessionID uuid.UUID + } + + RelocateEntityParameter struct { + Entity *ent.Entity + NewSource string + ParentFiles []int + PrimaryEntityParentFiles []int + } +) + +type FileClient interface { + TxOperator + // GetByHashID returns a file by its hashID. + GetByHashID(ctx context.Context, hashID string) (*ent.File, error) + // GetParentFile returns the parent folder of a given file + GetParentFile(ctx context.Context, root *ent.File, eagerLoading bool) (*ent.File, error) + // Search file by name from a given root. eagerLoading indicates whether to load edges determined by ctx. + GetChildFile(ctx context.Context, root *ent.File, ownerID int, child string, eagerLoading bool) (*ent.File, error) + // Get all files under a given root + GetChildFiles(ctx context.Context, args *ListFileParameters, ownerID int, roots ...*ent.File) (*ListFileResult, error) + // Root returns the root folder of a given user + Root(ctx context.Context, user *ent.User) (*ent.File, error) + // CreateOrGetFolder creates a folder with given name under root, or return the existed one + CreateFolder(ctx context.Context, root *ent.File, args *CreateFolderParameters) (*ent.File, error) + // GetByIDs returns files with given IDs. The result is paginated. Caller should keep calling this until + // returned next page is negative. + GetByIDs(ctx context.Context, ids []int, page int) ([]*ent.File, int, error) + // GetByID returns a file by its ID. + GetByID(ctx context.Context, ids int) (*ent.File, error) + // GetEntitiesByIDs returns entities with given IDs. 
The result is paginated. Caller should keep calling this until + // returned next page is negative. + GetEntitiesByIDs(ctx context.Context, ids []int, page int) ([]*ent.Entity, int, error) + // GetEntityByID returns an entity by its ID. + GetEntityByID(ctx context.Context, id int) (*ent.Entity, error) + // Rename renames a file + Rename(ctx context.Context, original *ent.File, newName string) (*ent.File, error) + // SetParent sets parent of group of files + SetParent(ctx context.Context, files []*ent.File, parent *ent.File) error + // CreateFile creates a file with given parameters, returns created File, default Entity, and storage diff for owner. + CreateFile(ctx context.Context, root *ent.File, args *CreateFileParameters) (*ent.File, *ent.Entity, StorageDiff, error) + // UpgradePlaceholder upgrades a placeholder entity to a real version entity + UpgradePlaceholder(ctx context.Context, file *ent.File, modifiedAt *time.Time, entityId int, entityType types.EntityType) error + // RemoveMetadata removes metadata from a file + RemoveMetadata(ctx context.Context, file *ent.File, keys ...string) error + // CreateEntity creates an entity with given parameters, returns created Entity, and storage diff for owner. + CreateEntity(ctx context.Context, file *ent.File, args *EntityParameters) (*ent.Entity, StorageDiff, error) + // RemoveStaleEntities hard-delete stale placeholder entities + RemoveStaleEntities(ctx context.Context, file *ent.File) (StorageDiff, error) + // RemoveEntitiesByID hard-delete entities by IDs. + RemoveEntitiesByID(ctx context.Context, ids ...int) (map[int]int64, error) + // CapEntities caps the number of entities of a given file. The oldest entities will be unlinked + // if entity count exceed limit. + CapEntities(ctx context.Context, file *ent.File, owner *ent.User, max int, entityType types.EntityType) (StorageDiff, error) + // UpsertMetadata update or insert metadata + UpsertMetadata(ctx context.Context, file *ent.File, data map[string]string, privateMask map[string]bool) error + // Copy copies a layer of file to its corresponding destination folder. dstMap is a map from src parent ID to dst parent Files. + Copy(ctx context.Context, files []*ent.File, dstMap map[int][]*ent.File) (map[int][]*ent.File, StorageDiff, error) + // Delete deletes a group of files (and related models) with given entity recycle option + Delete(ctx context.Context, files []*ent.File, options *types.EntityRecycleOption) ([]*ent.Entity, StorageDiff, error) + // StaleEntities returns stale entities of a given file. If ID is not provided, all entities + // will be examined. 
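As the doc comments above note, `GetByIDs` and `GetEntitiesByIDs` batch the ID list internally and return the index of the next batch, so callers loop until that index goes negative. A minimal sketch of that loop (how the `FileClient` is constructed is outside this diff and assumed here):

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// collectFiles drains every page of GetByIDs, as its doc comment asks.
func collectFiles(ctx context.Context, fc inventory.FileClient, ids []int) ([]*ent.File, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	var all []*ent.File
	for page := 0; page >= 0; {
		files, next, err := fc.GetByIDs(ctx, ids, page)
		if err != nil {
			return nil, err
		}
		all = append(all, files...)
		page = next // becomes -1 once the last batch has been returned
	}
	return all, nil
}
```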
+ StaleEntities(ctx context.Context, ids ...int) ([]*ent.Entity, error) + // QueryMetadata load metadata of a given file + QueryMetadata(ctx context.Context, root *ent.File) error + // SoftDelete soft-deletes a file, also renaming it to a random name + SoftDelete(ctx context.Context, file *ent.File) error + // SetPrimaryEntity sets primary entity of a file + SetPrimaryEntity(ctx context.Context, file *ent.File, entityID int) error + // UnlinkEntity unlinks an entity from a file + UnlinkEntity(ctx context.Context, entity *ent.Entity, file *ent.File, owner *ent.User) (StorageDiff, error) + // CreateDirectLink creates a direct link for a file + CreateDirectLink(ctx context.Context, fileID int, name string, speed int) (*ent.DirectLink, error) + // CountByTimeRange counts files created in a given time range + CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) + // CountEntityByTimeRange counts entities created in a given time range + CountEntityByTimeRange(ctx context.Context, start, end *time.Time) (int, error) + // CountEntityByStoragePolicyID counts entities by storage policy ID + CountEntityByStoragePolicyID(ctx context.Context, storagePolicyID int) (int, int, error) + // IsStoragePolicyUsedByEntities checks if a storage policy is used by entities + IsStoragePolicyUsedByEntities(ctx context.Context, policyID int) (bool, error) + // DeleteByUser deletes all files by a given user + DeleteByUser(ctx context.Context, uid int) error + // FlattenListFiles list files ignoring hierarchy + FlattenListFiles(ctx context.Context, args *FlattenListFileParameters) (*ListFileResult, error) + // Update updates a file + Update(ctx context.Context, file *ent.File) (*ent.File, error) + // ListEntities lists entities + ListEntities(ctx context.Context, args *ListEntityParameters) (*ListEntityResult, error) +} + +func NewFileClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) FileClient { + return &fileClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher} +} + +type fileClient struct { + maxSQlParam int + client *ent.Client + hasher hashid.Encoder +} + +func (c *fileClient) SetClient(newClient *ent.Client) TxOperator { + return &fileClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher} +} + +func (c *fileClient) GetClient() *ent.Client { + return c.client +} + +func (f *fileClient) Update(ctx context.Context, file *ent.File) (*ent.File, error) { + q := f.client.File.UpdateOne(file). + SetName(file.Name). + SetStoragePoliciesID(file.StoragePolicyFiles) + + existingMetadata, err := f.client.Metadata.Query().Where(metadata.FileID(file.ID)).All(ctx) + if err != nil { + return nil, err + } + + metadataIDs := lo.Map(file.Edges.Metadata, func(item *ent.Metadata, index int) int { + return item.ID + }) + + existingMetadataIds := lo.Map(existingMetadata, func(item *ent.Metadata, index int) int { + return item.ID + }) + + metadataToDelete := lo.Without(existingMetadataIds, metadataIDs...) + + // process metadata diff, delete metadata not in file.Edges.Metadata + _, err = f.client.Metadata.Delete().Where(metadata.FileID(file.ID), metadata.IDIn(metadataToDelete...)).Exec(schema.SkipSoftDelete(ctx)) + if err != nil { + return nil, err + } + + // Update presented metadata + for _, metadata := range file.Edges.Metadata { + f.client.Metadata.UpdateOne(metadata). + SetName(metadata.Name). + SetValue(metadata.Value). + SetIsPublic(metadata.IsPublic). 
+ Save(ctx) + } + + // process direct link diff, delete direct link not in file.Edges.DirectLinks + _, err = f.client.DirectLink.Delete().Where(directlink.FileID(file.ID), directlink.IDNotIn(lo.Map(file.Edges.DirectLinks, func(item *ent.DirectLink, index int) int { + return item.ID + })...)).Exec(ctx) + if err != nil { + return nil, err + } + + return q.Save(ctx) +} + +func (f *fileClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) { + if start == nil || end == nil { + return f.client.File.Query().Count(ctx) + } + + return f.client.File.Query().Where(file.CreatedAtGTE(*start), file.CreatedAtLT(*end)).Count(ctx) +} + +func (f *fileClient) CountEntityByTimeRange(ctx context.Context, start, end *time.Time) (int, error) { + if start == nil || end == nil { + return f.client.Entity.Query().Count(ctx) + } + + return f.client.Entity.Query().Where(entity.CreatedAtGTE(*start), entity.CreatedAtLT(*end)).Count(ctx) +} + +func (f *fileClient) CountEntityByStoragePolicyID(ctx context.Context, storagePolicyID int) (int, int, error) { + var v []struct { + Sum int `json:"sum"` + Count int `json:"count"` + } + + err := f.client.Entity.Query().Where(entity.StoragePolicyEntities(storagePolicyID)). + Aggregate( + ent.Sum(entity.FieldSize), + ent.Count(), + ).Scan(ctx, &v) + if err != nil { + return 0, 0, err + } + + return v[0].Count, v[0].Sum, nil +} + +func (f *fileClient) CreateDirectLink(ctx context.Context, file int, name string, speed int) (*ent.DirectLink, error) { + // Find existed + existed, err := f.client.DirectLink. + Query(). + Where(directlink.FileID(file), directlink.Name(name), directlink.Speed(speed)).First(ctx) + if err == nil { + return existed, nil + } + + return f.client.DirectLink. + Create(). + SetFileID(file). + SetName(name). + SetSpeed(speed). + SetDownloads(0). + Save(ctx) +} + +func (f *fileClient) GetByHashID(ctx context.Context, hashID string) (*ent.File, error) { + id, err := f.hasher.Decode(hashID, hashid.FileID) + if err != nil { + return nil, fmt.Errorf("ailed to decode hash id %q: %w", hashID, err) + } + + return withFileEagerLoading(ctx, f.client.File.Query().Where(file.ID(id))).First(ctx) +} + +func (f *fileClient) SoftDelete(ctx context.Context, file *ent.File) error { + newName := uuid.Must(uuid.NewV4()) + // Rename file to random UUID and make it stale + _, err := f.client.File.UpdateOne(file). + SetName(newName.String()). + ClearParent(). + Save(ctx) + if err != nil { + return fmt.Errorf("failed to soft delete file %d: %w", file.ID, err) + } + + return err +} + +func (f *fileClient) RemoveEntitiesByID(ctx context.Context, ids ...int) (map[int]int64, error) { + groups, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, ids) + // storageReduced stores the relation between owner ID and storage reduced. + storageReduced := make(map[int]int64) + + ctx = schema.SkipSoftDelete(ctx) + for _, group := range groups { + // For entities that are referenced by files, we need to reduce the storage of the owner of the files. + entities, err := f.client.Entity.Query(). + Where(group). + Where(entity.ReferenceCountGT(0)). + WithFile(). + All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query entities %v: %w", group, err) + } + + for _, entity := range entities { + if entity.Edges.File == nil { + continue + } + + for _, file := range entity.Edges.File { + storageReduced[file.OwnerID] -= entity.Size + } + } + + _, err = f.client.Entity.Delete(). + Where(group). 
+ Exec(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query entities %v: %w", group, err) + } + } + + return storageReduced, nil +} + +func (f *fileClient) StaleEntities(ctx context.Context, ids ...int) ([]*ent.Entity, error) { + res := make([]*ent.Entity, 0, len(ids)) + if len(ids) > 0 { + // If explicit IDs are given, we can query them directly + groups, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, ids) + for _, group := range groups { + entities, err := f.client.Entity.Query(). + Where(group). + All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query entities %v: %w", group, err) + } + res = append(res, entities...) + } + + return res, nil + } + + // No explicit IDs are given, we need to query all entities + entities, err := f.client.Entity.Query(). + Where(entity.Or( + entity.ReferenceCountLTE(0), + )). + All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query stale entities: %w", err) + } + + return entities, nil +} + +func (f *fileClient) DeleteByUser(ctx context.Context, uid int) error { + batchSize := capPageSize(f.maxSQlParam, intsets.MaxInt, 10) + for { + files, err := f.client.File.Query(). + WithEntities(). + Where(file.OwnerID(uid)). + Limit(batchSize). + All(ctx) + if err != nil { + return fmt.Errorf("failed to query files: %w", err) + } + + if len(files) == 0 { + break + } + + if _, _, err := f.Delete(ctx, files, nil); err != nil { + return fmt.Errorf("failed to delete files: %w", err) + } + } + + return nil +} + +func (f *fileClient) Delete(ctx context.Context, files []*ent.File, options *types.EntityRecycleOption) ([]*ent.Entity, StorageDiff, error) { + // 1. Decrease reference count for all entities; + // entities stores the relation between its reference count in `files` and entity ID. + entities := make(map[int]int) + // storageReduced stores the relation between owner ID and storage reduced. + storageReduced := make(map[int]int64) + for _, fi := range files { + fileEntities, err := fi.Edges.EntitiesOrErr() + if err != nil { + return nil, nil, err + } + + for _, e := range fileEntities { + entities[e.ID]++ + storageReduced[fi.OwnerID] -= e.Size + } + } + + // Group entities by their reference count. + uniqueEntities := lo.Keys(entities) + entitiesGrouped := lo.GroupBy(uniqueEntities, func(e int) int { + return entities[e] + }) + + for ref, entityGroup := range entitiesGrouped { + entityPageGroup, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, entityGroup) + for _, group := range entityPageGroup { + if err := f.client.Entity.Update(). + Where(group). + AddReferenceCount(-1 * ref). + Exec(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to decrease reference count for entities %v: %w", group, err) + } + } + } + + // 2. Filter out entities with <=0 reference count, Update recycle options for above entities; + entityGroup, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, uniqueEntities) + toBeRecycled := make([]*ent.Entity, 0, len(entities)) + for _, group := range entityGroup { + e, err := f.client.Entity.Query().Where(group).Where(entity.ReferenceCountLTE(0)).All(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to query orphan entities %v: %w", group, err) + } + + toBeRecycled = append(toBeRecycled, e...) + } + + // 3. 
Update recycle options for above entities; + pageSize := capPageSize(f.maxSQlParam, intsets.MaxInt, 10) + chunks := lo.Chunk(lo.Map(toBeRecycled, func(item *ent.Entity, index int) int { + return item.ID + }), max(pageSize, 1)) + for _, chunk := range chunks { + if err := f.client.Entity.Update(). + Where(entity.IDIn(chunk...)). + SetRecycleOptions(options). + Exec(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to update recycle options for entities %v: %w", chunk, err) + } + } + + hardDeleteCtx := schema.SkipSoftDelete(ctx) + fileGroups, chunks := f.batchInCondition(intsets.MaxInt, 10, 1, + lo.Map(files, func(file *ent.File, index int) int { + return file.ID + }), + ) + + for i, group := range fileGroups { + // 4. Delete shares/metadata/directlinks if needed; + if _, err := f.client.Share.Delete().Where(share.HasFileWith(group)).Exec(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to delete shares of files %v: %w", group, err) + } + + if _, err := f.client.Metadata.Delete().Where(metadata.FileIDIn(chunks[i]...)).Exec(schema.SkipSoftDelete(ctx)); err != nil { + return nil, nil, fmt.Errorf("failed to delete metadata of files %v: %w", group, err) + } + + if _, err := f.client.DirectLink.Delete().Where(directlink.FileIDIn(chunks[i]...)).Exec(hardDeleteCtx); err != nil { + return nil, nil, fmt.Errorf("failed to delete direct links of files %v: %w", group, err) + } + + // 5. Delete files. + if _, err := f.client.File.Delete().Where(group).Exec(hardDeleteCtx); err != nil { + return nil, nil, fmt.Errorf("failed to delete files %v: %w", group, err) + } + } + + return toBeRecycled, storageReduced, nil +} + +func (f *fileClient) Copy(ctx context.Context, files []*ent.File, dstMap map[int][]*ent.File) (map[int][]*ent.File, StorageDiff, error) { + pageSize := capPageSize(f.maxSQlParam, intsets.MaxInt, 10) + // 1. Copy files and metadata + copyFileStm := lo.Map(files, func(file *ent.File, index int) *ent.FileCreate { + + stm := f.client.File.Create(). + SetName(file.Name). + SetOwnerID(dstMap[file.FileChildren][0].OwnerID). + SetSize(file.Size). + SetType(file.Type). + SetParent(dstMap[file.FileChildren][0]). + SetIsSymbolic(file.IsSymbolic) + if file.StoragePolicyFiles > 0 { + stm.SetStoragePolicyFiles(file.StoragePolicyFiles) + } + if file.PrimaryEntity > 0 { + stm.SetPrimaryEntity(file.PrimaryEntity) + } + + return stm + }) + + metadataStm := []*ent.MetadataCreate{} + entityStm := []*ent.EntityUpdate{} + newDstMap := make(map[int][]*ent.File, len(files)) + sizeDiff := int64(0) + for index, stm := range copyFileStm { + newFile, err := stm.Save(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to copy file: %w", err) + } + + fileMetadata, err := files[index].Edges.MetadataOrErr() + if err != nil { + return nil, nil, fmt.Errorf("failed to get metadata of file: %w", err) + } + + metadataStm = append(metadataStm, lo.Map(fileMetadata, func(metadata *ent.Metadata, index int) *ent.MetadataCreate { + return f.client.Metadata.Create(). + SetName(metadata.Name). + SetValue(metadata.Value). + SetFile(newFile). + SetIsPublic(metadata.IsPublic) + })...) 
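Stepping back from the statement-level details for a moment: `Copy` never duplicates entity blobs. Each copied file row is re-pointed at the existing entities (reference count +1, handled just below), and the added bytes are charged to the destination owner through the returned `StorageDiff`. A hypothetical caller sketch, assuming the children were loaded with their `Metadata` and `Entities` edges and that `dstMap` is keyed by the source parent ID as the interface comment describes:

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// copyChildren copies the direct children of srcParent into dstParent.
func copyChildren(ctx context.Context, fc inventory.FileClient, children []*ent.File,
	srcParent, dstParent *ent.File) (inventory.StorageDiff, error) {
	dstMap := map[int][]*ent.File{
		srcParent.ID: {dstParent}, // source parent ID -> destination parent(s)
	}
	_, diff, err := fc.Copy(ctx, children, dstMap)
	// diff maps owner ID -> bytes gained; applying it to the owner's
	// storage counter is the caller's job and is not shown here.
	return diff, err
}
```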
+ + fileEntities, err := files[index].Edges.EntitiesOrErr() + if err != nil { + return nil, nil, fmt.Errorf("failed to get entities of file: %w", err) + } + ids := lo.FilterMap(fileEntities, func(entity *ent.Entity, index int) (int, bool) { + // Skip entities that are still uploading + if entity.UploadSessionID != nil { + return 0, false + } + sizeDiff += entity.Size + return entity.ID, true + }) + + entityBatch, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, ids) + entityStm = append(entityStm, lo.Map(entityBatch, func(batch predicate.Entity, index int) *ent.EntityUpdate { + return f.client.Entity.Update().Where(batch).AddReferenceCount(1).AddFile(newFile) + })...) + + newAncestorChain := append(newDstMap[files[index].ID], newFile) + newAncestorChain[0], newAncestorChain[len(newAncestorChain)-1] = newAncestorChain[len(newAncestorChain)-1], newAncestorChain[0] + newDstMap[files[index].ID] = newAncestorChain + } + + // 2. Copy metadata by group + chunkedMetadataStm := lo.Chunk(metadataStm, pageSize/10) + for _, chunk := range chunkedMetadataStm { + if err := f.client.Metadata.CreateBulk(chunk...).Exec(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to copy metadata: %w", err) + } + } + + // 3. Copy entity relations + for _, stm := range entityStm { + if err := stm.Exec(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to copy entity: %w", err) + } + } + + return newDstMap, map[int]int64{dstMap[files[0].FileChildren][0].OwnerID: sizeDiff}, nil +} + +func (f *fileClient) UpsertMetadata(ctx context.Context, file *ent.File, data map[string]string, privateMask map[string]bool) error { + // Validate value length + for key, value := range data { + if len(value) > MaxMetadataLen { + return fmt.Errorf("metadata value of key %s is too long", key) + } + } + + if err := f.client.Metadata. + CreateBulk(lo.MapToSlice(data, func(key string, value string) *ent.MetadataCreate { + isPrivate := false + if privateMask != nil { + _, isPrivate = privateMask[key] + } + + return f.client.Metadata.Create(). + SetName(key). + SetValue(value). + SetFile(file). + SetIsPublic(!isPrivate). + SetNillableDeletedAt(nil) + })...). + OnConflictColumns(metadata.FieldFileID, metadata.FieldName). + UpdateNewValues(). + Exec(ctx); err != nil { + return fmt.Errorf("failed to upsert metadata: %w", err) + } + + return nil +} + +func (f *fileClient) RemoveMetadata(ctx context.Context, file *ent.File, keys ...string) error { + if len(keys) == 0 { + return nil + } + ctx = schema.SkipSoftDelete(ctx) + groups, _ := f.batchInConditionMetadataName(intsets.MaxInt, 10, 1, keys) + for _, group := range groups { + if _, err := f.client.Metadata.Delete().Where(metadata.FileID(file.ID), group).Exec(ctx); err != nil { + return fmt.Errorf("failed to remove metadata: %v", err) + } + } + + return nil +} + +func (f *fileClient) UpgradePlaceholder(ctx context.Context, file *ent.File, modifiedAt *time.Time, entityId int, + entityType types.EntityType) error { + entities, err := file.Edges.EntitiesOrErr() + if err != nil { + return err + } + + placeholder, found := lo.Find(entities, func(e *ent.Entity) bool { + return e.ID == entityId + }) + if !found { + return fmt.Errorf("no identity with id %d entity for file %d", entityId, file.ID) + } + + stm := f.client.Entity. + UpdateOne(placeholder). 
+ ClearUploadSessionID() + if modifiedAt != nil { + stm.SetUpdatedAt(*modifiedAt) + } + + if err := stm.Exec(ctx); err != nil { + return fmt.Errorf("failed to upgrade placeholder: %v", err) + } + + if entityType == types.EntityTypeVersion { + if err := f.client.File.UpdateOne(file). + SetSize(placeholder.Size). + SetPrimaryEntity(placeholder.ID). + Exec(ctx); err != nil { + return fmt.Errorf("failed to upgrade file primary entity: %v", err) + } + } + return nil +} + +func (f *fileClient) SetPrimaryEntity(ctx context.Context, file *ent.File, entityID int) error { + return f.client.File.UpdateOne(file).SetPrimaryEntity(entityID).Exec(ctx) +} + +func (f *fileClient) CreateFile(ctx context.Context, root *ent.File, args *CreateFileParameters) (*ent.File, *ent.Entity, StorageDiff, error) { + var defaultEntity *ent.Entity + stm := f.client.File. + Create(). + SetOwnerID(root.OwnerID). + SetType(int(args.FileType)). + SetName(args.Name). + SetParent(root). + SetIsSymbolic(args.IsSymbolic). + SetStoragePoliciesID(args.StoragePolicyID) + newFile, err := stm.Save(ctx) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create file: %v", err) + } + + // Create default primary file entity if needed + var storageDiff StorageDiff + if args.EntityParameters != nil { + args.EntityParameters.OwnerID = root.OwnerID + args.EntityParameters.StoragePolicyID = args.StoragePolicyID + defaultEntity, storageDiff, err = f.CreateEntity(ctx, newFile, args.EntityParameters) + if err != nil { + return nil, nil, storageDiff, fmt.Errorf("failed to create default entity: %v", err) + } + } + + // Create metadata if needed + if len(args.Metadata) > 0 { + _, err := f.client.Metadata. + CreateBulk(lo.MapToSlice(args.Metadata, func(key, value string) *ent.MetadataCreate { + _, isPrivate := args.MetadataPrivateMask[key] + return f.client.Metadata.Create(). + SetName(key). + SetValue(value). + SetFile(newFile). + SetIsPublic(!isPrivate) + })...). + Save(ctx) + if err != nil { + return nil, nil, storageDiff, fmt.Errorf("failed to create metadata: %v", err) + } + } + + return newFile, defaultEntity, storageDiff, err + +} + +func (f *fileClient) CapEntities(ctx context.Context, file *ent.File, owner *ent.User, max int, entityType types.EntityType) (StorageDiff, error) { + entities, err := file.Edges.EntitiesOrErr() + if err != nil { + return nil, fmt.Errorf("failed to cap file entities: %v", err) + } + + versionCount := 0 + diff := make(StorageDiff) + for _, e := range entities { + if e.Type != int(entityType) { + continue + } + + versionCount++ + if versionCount > max { + // By default, eager-loaded entity is sorted by ID in descending order. + // So we can just unlink the entity and it will be the older version. 
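For readers tracking the `StorageDiff` bookkeeping here: `UnlinkEntity` (defined right after this function) returns `{owner.ID: -entity.Size}`, and `CapEntities` merges those per-owner deltas as it trims versions beyond the limit. A hedged caller sketch, assuming the file was loaded with its `Entities` edge and that the retention limit comes from configuration elsewhere:

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/inventory"
	"github.com/cloudreve/Cloudreve/v4/inventory/types"
)

// capVersions keeps at most maxVersions version entities on f and
// returns the per-owner byte delta (negative or empty) that was freed.
func capVersions(ctx context.Context, fc inventory.FileClient, f *ent.File,
	owner *ent.User, maxVersions int) (inventory.StorageDiff, error) {
	return fc.CapEntities(ctx, f, owner, maxVersions, types.EntityTypeVersion)
}
```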
+ newDiff, err := f.UnlinkEntity(ctx, e, file, owner) + if err != nil { + return diff, fmt.Errorf("failed to cap file entities: %v", err) + } + + diff.Merge(newDiff) + } + } + + return diff, nil +} + +func (f *fileClient) UnlinkEntity(ctx context.Context, entity *ent.Entity, file *ent.File, owner *ent.User) (StorageDiff, error) { + if err := f.client.Entity.UpdateOne(entity).RemoveFile(file).AddReferenceCount(-1).Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to unlink entity: %v", err) + } + + return map[int]int64{owner.ID: entity.Size * int64(-1)}, nil +} + +func (f *fileClient) IsStoragePolicyUsedByEntities(ctx context.Context, policyID int) (bool, error) { + res, err := f.client.Entity.Query().Where(entity.StoragePolicyEntities(policyID)).Limit(1).All(ctx) + if err != nil { + return false, fmt.Errorf("failed to check if storage policy is used by entities: %v", err) + } + + if len(res) > 0 { + return true, nil + } + + return false, nil +} + +func (f *fileClient) RemoveStaleEntities(ctx context.Context, file *ent.File) (StorageDiff, error) { + entities, err := file.Edges.EntitiesOrErr() + if err != nil { + return nil, fmt.Errorf("failed to get stale entities: %v", err) + } + + sizeReduced := int64(0) + ids := lo.FilterMap(entities, func(e *ent.Entity, index int) (int, bool) { + if e.ReferenceCount == 1 && e.UploadSessionID != nil { + sizeReduced += e.Size + return e.ID, true + } + + return 0, false + }) + + if len(ids) > 0 { + if err = f.client.Entity.Update(). + Where(entity.IDIn(ids...)). + RemoveFile(file). + AddReferenceCount(-1).Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to remove stale entities: %v", err) + } + } + + return map[int]int64{file.OwnerID: sizeReduced * -1}, nil +} + +func (f *fileClient) CreateEntity(ctx context.Context, file *ent.File, args *EntityParameters) (*ent.Entity, StorageDiff, error) { + createdBy := UserFromContext(ctx) + stm := f.client.Entity. + Create(). + SetType(int(args.EntityType)). + SetSource(args.Source). + SetSize(args.Size). 
+ SetStoragePolicyID(args.StoragePolicyID) + + if createdBy != nil && !IsAnonymousUser(createdBy) { + stm.SetUser(createdBy) + } + + if args.ModifiedAt != nil { + stm.SetUpdatedAt(*args.ModifiedAt) + } + + if args.UploadSessionID != uuid.Nil { + stm.SetUploadSessionID(args.UploadSessionID) + } + + created, err := stm.Save(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to create file entity: %v", err) + } + + diff := map[int]int64{file.OwnerID: created.Size} + + if err := f.client.File.UpdateOne(file).AddEntities(created).Exec(ctx); err != nil { + return nil, diff, fmt.Errorf("failed to add file entity: %v", err) + } + + return created, diff, nil +} + +func (f *fileClient) SetParent(ctx context.Context, files []*ent.File, parent *ent.File) error { + groups, _ := f.batchInCondition(intsets.MaxInt, 10, 1, lo.Map(files, func(file *ent.File, index int) int { + return file.ID + })) + for _, group := range groups { + _, err := f.client.File.Update().SetParent(parent).Where(group).Save(ctx) + if err != nil { + return fmt.Errorf("failed to set parent field: %w", err) + } + } + + return nil +} + +func (f *fileClient) GetParentFile(ctx context.Context, root *ent.File, eagerLoading bool) (*ent.File, error) { + query := f.client.File.QueryParent(root) + if eagerLoading { + query = withFileEagerLoading(ctx, query) + } + + return query.First(ctx) +} + +func (f *fileClient) QueryMetadata(ctx context.Context, root *ent.File) error { + metadata, err := f.client.File.QueryMetadata(root).All(ctx) + if err != nil { + return err + } + root.SetMetadata(metadata) + return nil +} + +func (f *fileClient) GetChildFile(ctx context.Context, root *ent.File, ownerID int, child string, eagerLoading bool) (*ent.File, error) { + query := f.childFileQuery(ownerID, false, root) + if eagerLoading { + query = withFileEagerLoading(ctx, query) + } + + return query. + Where(file.Name(child)).First(ctx) +} + +func (f *fileClient) GetChildFiles(ctx context.Context, args *ListFileParameters, ownerID int, roots ...*ent.File) (*ListFileResult, error) { + rawQuery := f.childFileQuery(ownerID, args.SharedWithMe, roots...) + query := withFileEagerLoading(ctx, rawQuery) + if args.Search != nil { + query = f.searchQuery(query, args.Search, roots, ownerID) + } + + var ( + files []*ent.File + err error + paginationRes *PaginationResults + ) + + if args.UseCursorPagination || args.Search != nil { + files, paginationRes, err = f.cursorPagination(ctx, query, args, 10) + } else { + files, paginationRes, err = f.offsetPagination(ctx, query, args, 10) + } + + if err != nil { + return nil, fmt.Errorf("query failed with paginiation: %w", err) + } + + return &ListFileResult{ + Files: files, + PaginationResults: paginationRes, + MixedType: args.MixedType, + }, nil +} + +func (f *fileClient) Root(ctx context.Context, user *ent.User) (*ent.File, error) { + return f.client.User. + QueryFiles(user). + Where(file.Not(file.HasParent())). + Where(file.Name(RootFolderName)). + First(ctx) +} + +func (f *fileClient) CreateFolder(ctx context.Context, root *ent.File, args *CreateFolderParameters) (*ent.File, error) { + stm := f.client.File. + Create(). + SetOwnerID(args.Owner). + SetType(int(types.FileTypeFolder)). + SetIsSymbolic(args.IsSymbolic). 
+ SetName(args.Name) + if root != nil { + stm.SetParent(root).SetType(int(types.FileTypeFolder)) + } + + fid, err := stm.OnConflict(sql.ConflictColumns(file.FieldFileChildren, file.FieldName)).Ignore().ID(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create folder: %w", err) + } + + newFolder, err := f.client.File.Get(ctx, fid) + if err != nil { + return nil, fmt.Errorf("failed to get folder: %w", err) + } + + if len(args.Metadata) > 0 { + _, err := f.client.Metadata. + CreateBulk(lo.MapToSlice(args.Metadata, func(key, value string) *ent.MetadataCreate { + _, isPrivate := args.MetadataPrivateMask[key] + return f.client.Metadata.Create(). + SetName(key). + SetValue(value). + SetFile(newFolder). + SetIsPublic(!isPrivate) + })...). + Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create metadata: %v", err) + } + } + + return newFolder, err + +} + +func (f *fileClient) Rename(ctx context.Context, original *ent.File, newName string) (*ent.File, error) { + return f.client.File.UpdateOne(original).SetName(newName).Save(ctx) +} + +func (f *fileClient) GetEntitiesByIDs(ctx context.Context, ids []int, page int) ([]*ent.Entity, int, error) { + groups, _ := f.batchInConditionEntityID(intsets.MaxInt, 10, 1, ids) + if page >= len(groups) || page < 0 { + return nil, -1, fmt.Errorf("page out of range") + } + + res, err := f.client.Entity.Query().Where(groups[page]).All(ctx) + if err != nil { + return nil, page, err + } + + if page+1 >= len(groups) { + return res, -1, nil + } + + return res, page + 1, nil +} + +func (f *fileClient) GetEntityByID(ctx context.Context, id int) (*ent.Entity, error) { + return withEntityEagerLoading(ctx, f.client.Entity.Query().Where(entity.ID(id))).First(ctx) +} + +func (f *fileClient) GetByID(ctx context.Context, ids int) (*ent.File, error) { + return withFileEagerLoading(ctx, f.client.File.Query().Where(file.ID(ids))).First(ctx) +} + +func (f *fileClient) GetByIDs(ctx context.Context, ids []int, page int) ([]*ent.File, int, error) { + groups, _ := f.batchInCondition(intsets.MaxInt, 10, 1, ids) + if page >= len(groups) || page < 0 { + return nil, -1, fmt.Errorf("page out of range") + } + + res, err := withFileEagerLoading(ctx, f.client.File.Query().Where(groups[page])).All(ctx) + if err != nil { + return nil, page, err + } + + if page+1 >= len(groups) { + return res, -1, nil + } + + return res, page + 1, nil +} + +func (f *fileClient) FlattenListFiles(ctx context.Context, args *FlattenListFileParameters) (*ListFileResult, error) { + query := f.client.File.Query().Where(file.Type(int(types.FileTypeFile)), file.IsSymbolic(false)) + + if args.UserID > 0 { + query = query.Where(file.OwnerID(args.UserID)) + } + + if args.StoragePolicyID > 0 { + query = query.Where(file.StoragePolicyFiles(args.StoragePolicyID)) + } + + if args.Name != "" { + query = query.Where(file.NameContainsFold(args.Name)) + } + + query.Order(getFileOrderOption(&ListFileParameters{ + PaginationArgs: args.PaginationArgs, + })...) 
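`FlattenListFiles` ignores the folder hierarchy entirely: it only matches real, non-symbolic file rows, narrows them by owner, storage policy, or a case-insensitive name fragment, and pages through with plain limit/offset (below). A usage sketch, assuming `PaginationArgs` exposes the `Page`/`PageSize` fields used in this function:

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// listUserReports lists the first 50 files owned by uid whose name
// contains "report", regardless of where they sit in the tree.
func listUserReports(ctx context.Context, fc inventory.FileClient, uid int) (*inventory.ListFileResult, error) {
	return fc.FlattenListFiles(ctx, &inventory.FlattenListFileParameters{
		PaginationArgs: &inventory.PaginationArgs{Page: 0, PageSize: 50},
		UserID:         uid,
		Name:           "report",
	})
}
```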
+ + // Count total items + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to count files: %w", err) + } + + files, err := withFileEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list files: %w", err) + } + + return &ListFileResult{ + Files: files, + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: args.PageSize, + }, + }, nil +} + +func (f *fileClient) ListEntities(ctx context.Context, args *ListEntityParameters) (*ListEntityResult, error) { + query := f.client.Entity.Query() + if args.EntityType != nil { + query = query.Where(entity.Type(int(*args.EntityType))) + } + + if args.UserID > 0 { + query = query.Where(entity.CreatedBy(args.UserID)) + } + + if args.StoragePolicyID > 0 { + query = query.Where(entity.StoragePolicyEntities(args.StoragePolicyID)) + } + + query.Order(getEntityOrderOption(args)...) + + // Count total items + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to count files: %w", err) + } + + entities, err := withEntityEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list entities: %w", err) + } + + return &ListEntityResult{ + Entities: entities, + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: args.PageSize, + }, + }, nil +} diff --git a/inventory/file_utils.go b/inventory/file_utils.go new file mode 100644 index 00000000..f141c975 --- /dev/null +++ b/inventory/file_utils.go @@ -0,0 +1,503 @@ +package inventory + +import ( + "context" + "fmt" + "strings" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/entity" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/metadata" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/samber/lo" +) + +func (f *fileClient) searchQuery(q *ent.FileQuery, args *SearchFileParameters, parents []*ent.File, ownerId int) *ent.FileQuery { + if len(parents) == 1 && parents[0] == nil { + q = q.Where(file.OwnerID(ownerId)) + } else { + q = q.Where( + file.HasParentWith( + file.IDIn(lo.Map(parents, func(item *ent.File, index int) int { + return item.ID + })..., + ), + ), + ) + } + + if len(args.Name) > 0 { + namePredicates := lo.Map(args.Name, func(item string, index int) predicate.File { + // If start and ends with quotes, treat as exact match + if strings.HasPrefix(item, "\"") && strings.HasSuffix(item, "\"") { + return file.NameContains(strings.Trim(item, "\"")) + } + + // if contain wildcard, use transform to sql like + if strings.Contains(item, SearchWildcard) { + pattern := strings.ReplaceAll(item, SearchWildcard, "%") + if pattern[0] != '%' && pattern[len(pattern)-1] != '%' { + // if not start with wildcard, add prefix wildcard + pattern = "%" + pattern + "%" + } + + return func(s *sql.Selector) { + s.Where(sql.Like(file.FieldName, pattern)) + } + } + + if args.CaseFolding { + return file.NameContainsFold(item) + } + + return file.NameContains(item) + }) + + if args.NameOperatorOr { + q = q.Where(file.Or(namePredicates...)) + } else { + q = q.Where(file.And(namePredicates...)) + } + } + + if args.Type != nil { + q = q.Where(file.TypeEQ(int(*args.Type))) + } + + if 
len(args.Metadata) > 0 { + metaPredicates := lo.MapToSlice(args.Metadata, func(name string, value string) predicate.Metadata { + nameEq := metadata.NameEQ(value) + if name == "" { + return nameEq + } else { + valueContain := metadata.ValueContainsFold(value) + return metadata.And(metadata.NameEQ(name), valueContain) + } + }) + metaPredicates = append(metaPredicates, metadata.IsPublic(true)) + q.Where(file.HasMetadataWith(metadata.And(metaPredicates...))) + } + + if args.SizeLte > 0 || args.SizeGte > 0 { + q = q.Where(file.SizeGTE(args.SizeGte), file.SizeLTE(args.SizeLte)) + } + + if args.CreatedAtLte != nil { + q = q.Where(file.CreatedAtLTE(*args.CreatedAtLte)) + } + + if args.CreatedAtGte != nil { + q = q.Where(file.CreatedAtGTE(*args.CreatedAtGte)) + } + + if args.UpdatedAtLte != nil { + q = q.Where(file.UpdatedAtLTE(*args.UpdatedAtLte)) + } + + if args.UpdatedAtGte != nil { + q = q.Where(file.UpdatedAtGTE(*args.UpdatedAtGte)) + } + + return q +} + +// ChildFileQuery generates query for child file(s) of a given set of root +func (f *fileClient) childFileQuery(ownerID int, isSymbolic bool, root ...*ent.File) *ent.FileQuery { + rawQuery := f.client.File.Query() + if len(root) == 1 && root[0] != nil { + // Query children of one single root + rawQuery = f.client.File.QueryChildren(root[0]) + } else if root[0] == nil { + // Query orphan files with owner ID + predicates := []predicate.File{ + file.NameNEQ(RootFolderName), + } + + if ownerID > 0 { + predicates = append(predicates, file.OwnerIDEQ(ownerID)) + } + + if isSymbolic { + predicates = append(predicates, file.And(file.IsSymbolic(true), file.FileChildrenNotNil())) + } else { + predicates = append(predicates, file.Not(file.HasParent())) + } + + rawQuery = f.client.File.Query().Where( + file.And(predicates...), + ) + } else { + // Query children of multiple roots + rawQuery. + Where( + file.HasParentWith( + file.IDIn(lo.Map(root, func(item *ent.File, index int) int { + return item.ID + })...), + ), + ) + } + + return rawQuery +} + +// batchInCondition returns a list of predicates that divide original group into smaller ones +// to bypass DB limitations. +func (f *fileClient) batchInCondition(pageSize, margin int, multiply int, ids []int) ([]predicate.File, [][]int) { + pageSize = capPageSize(f.maxSQlParam, pageSize, margin) + chunks := lo.Chunk(ids, max(pageSize/multiply, 1)) + return lo.Map(chunks, func(item []int, index int) predicate.File { + return file.IDIn(item...) + }), chunks +} + +func (f *fileClient) batchInConditionMetadataName(pageSize, margin int, multiply int, keys []string) ([]predicate.Metadata, [][]string) { + pageSize = capPageSize(f.maxSQlParam, pageSize, margin) + chunks := lo.Chunk(keys, max(pageSize/multiply, 1)) + return lo.Map(chunks, func(item []string, index int) predicate.Metadata { + return metadata.NameIn(item...) + }), chunks +} + +func (f *fileClient) batchInConditionEntityID(pageSize, margin int, multiply int, keys []int) ([]predicate.Entity, [][]int) { + pageSize = capPageSize(f.maxSQlParam, pageSize, margin) + chunks := lo.Chunk(keys, max(pageSize/multiply, 1)) + return lo.Map(chunks, func(item []int, index int) predicate.Entity { + return entity.IDIn(item...) + }), chunks +} + +// cursorPagination perform pagination with cursor, which is faster than fast pagination, but less flexible. 
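The cursor path below is keyset pagination: instead of an OFFSET that grows with the page number, the last row of the previous page is encoded into a `PageToken` and the next query resumes strictly after it, using the row ID as a tie-breaker for duplicate sort values. Conceptually, for an ascending name sort the resume predicate is the one sketched here; this is only an illustrative restatement of the `fileCursorQuery` map further down, and it assumes `PageToken` and its fields are exported from this package as they appear in this file:

```go
package example // illustrative only, not part of this change

import (
	"github.com/cloudreve/Cloudreve/v4/ent/file"
	"github.com/cloudreve/Cloudreve/v4/ent/predicate"
	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// afterCursorNameAsc translates to:
//   WHERE name > :last_name OR (name = :last_name AND id > :last_id)
// so rows already returned on the previous page can never reappear.
func afterCursorNameAsc(token *inventory.PageToken) predicate.File {
	return file.Or(
		file.NameGT(token.String),
		file.And(file.Name(token.String), file.IDGT(token.ID)),
	)
}
```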
+func (f *fileClient) cursorPagination(ctx context.Context, query *ent.FileQuery, + args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) { + pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin) + query.Order(getFileOrderOption(args)...) + currentPage := 0 + // Three types of query option + queryPaged := []*ent.FileQuery{ + query.Clone(). + Where(file.TypeEQ(int(types.FileTypeFolder))), + query.Clone(). + Where(file.TypeEQ(int(types.FileTypeFile))), + query.Clone(). + Where(file.TypeIn(int(types.FileTypeFolder), int(types.FileTypeFile))), + } + + var ( + pageToken *PageToken + err error + ) + if args.PageToken != "" { + pageToken, err = pageTokenFromString(args.PageToken, f.hasher, hashid.FileID) + if err != nil { + return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err) + } + } + queryPaged = getFileCursorQuery(args, pageToken, queryPaged) + + // Use page size + 1 to determine if there are more items to come + queryPaged[0].Limit(pageSize + 1) + + files, err := queryPaged[0]. + All(ctx) + if err != nil { + return nil, nil, err + } + + nextStartWithFile := false + if pageToken != nil && pageToken.StartWithFile { + nextStartWithFile = true + } + if len(files) < pageSize+1 && len(queryPaged) > 1 && !args.MixedType && !args.FolderOnly { + queryPaged[1].Limit(pageSize + 1 - len(files)) + filesContinue, err := queryPaged[1]. + All(ctx) + if err != nil { + return nil, nil, err + } + + nextStartWithFile = true + files = append(files, filesContinue...) + } + + // More items to come + nextTokenStr := "" + if len(files) > pageSize { + lastItem := files[len(files)-2] + nextToken, err := getFileNextPageToken(f.hasher, lastItem, args, nextStartWithFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate next page token: %w", err) + } + + nextTokenStr = nextToken + } + + return lo.Subset(files, 0, uint(pageSize)), &PaginationResults{ + Page: currentPage, + PageSize: pageSize, + NextPageToken: nextTokenStr, + IsCursor: true, + }, nil + +} + +// offsetPagination perform traditional pagination with minor optimizations. +func (f *fileClient) offsetPagination(ctx context.Context, query *ent.FileQuery, + args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) { + pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin) + queryWithoutOrder := query.Clone() + query.Order(getFileOrderOption(args)...) + + // Count total items by type + var v []struct { + Type int `json:"type"` + Count int `json:"count"` + } + err := queryWithoutOrder.Clone(). + GroupBy(file.FieldType). + Aggregate(ent.Count()). + Scan(ctx, &v) + if err != nil { + return nil, nil, err + } + + folderCount := 0 + fileCount := 0 + for _, item := range v { + if item.Type == int(types.FileTypeFolder) { + folderCount = item.Count + } else { + fileCount = item.Count + } + } + + allFiles := make([]*ent.File, 0, pageSize) + folderLimit := 0 + if (args.Page+1)*pageSize > folderCount { + folderLimit = folderCount - args.Page*pageSize + if folderLimit < 0 { + folderLimit = 0 + } + } else { + folderLimit = pageSize + } + + if folderLimit <= pageSize && folderLimit > 0 { + // Folder still remains + folders, err := query.Clone(). + Limit(folderLimit). + Offset(args.Page * pageSize). + Where(file.TypeEQ(int(types.FileTypeFolder))).All(ctx) + if err != nil { + return nil, nil, err + } + + allFiles = append(allFiles, folders...) + } + + if folderLimit < pageSize { + files, err := query.Clone(). + Limit(pageSize - folderLimit). 
+ Offset((args.Page * pageSize) + folderLimit - folderCount). + Where(file.TypeEQ(int(types.FileTypeFile))). + All(ctx) + if err != nil { + return nil, nil, err + } + + allFiles = append(allFiles, files...) + } + + return allFiles, &PaginationResults{ + TotalItems: folderCount + fileCount, + Page: args.Page, + PageSize: pageSize, + }, nil +} + +func withFileEagerLoading(ctx context.Context, q *ent.FileQuery) *ent.FileQuery { + if v, ok := ctx.Value(LoadFileEntity{}).(bool); ok && v { + q.WithEntities(func(m *ent.EntityQuery) { + m.Order(ent.Desc(entity.FieldID)) + withEntityEagerLoading(ctx, m) + }) + } + if v, ok := ctx.Value(LoadFileMetadata{}).(bool); ok && v { + q.WithMetadata() + } + if v, ok := ctx.Value(LoadFilePublicMetadata{}).(bool); ok && v { + q.WithMetadata(func(m *ent.MetadataQuery) { + m.Where(metadata.IsPublic(true)) + }) + } + if v, ok := ctx.Value(LoadFileShare{}).(bool); ok && v { + q.WithShares() + } + if v, ok := ctx.Value(LoadFileUser{}).(bool); ok && v { + q.WithOwner(func(query *ent.UserQuery) { + withUserEagerLoading(ctx, query) + }) + } + if v, ok := ctx.Value(LoadFileDirectLink{}).(bool); ok && v { + q.WithDirectLinks() + } + + return q +} + +func withEntityEagerLoading(ctx context.Context, q *ent.EntityQuery) *ent.EntityQuery { + if v, ok := ctx.Value(LoadEntityUser{}).(bool); ok && v { + q.WithUser() + } + + if v, ok := ctx.Value(LoadEntityStoragePolicy{}).(bool); ok && v { + q.WithStoragePolicy() + } + + if v, ok := ctx.Value(LoadEntityFile{}).(bool); ok && v { + q.WithFile(func(fq *ent.FileQuery) { + withFileEagerLoading(ctx, fq) + }) + } + + return q +} + +func getFileOrderOption(args *ListFileParameters) []file.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case file.FieldName: + return []file.OrderOption{file.ByName(orderTerm), file.ByID(orderTerm)} + case file.FieldSize: + return []file.OrderOption{file.BySize(orderTerm), file.ByID(orderTerm)} + case file.FieldUpdatedAt: + return []file.OrderOption{file.ByUpdatedAt(orderTerm), file.ByID(orderTerm)} + default: + return []file.OrderOption{file.ByID(orderTerm)} + } +} + +func getEntityOrderOption(args *ListEntityParameters) []entity.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case entity.FieldSize: + return []entity.OrderOption{entity.BySize(orderTerm), entity.ByID(orderTerm)} + case entity.FieldUpdatedAt: + return []entity.OrderOption{entity.ByUpdatedAt(orderTerm), entity.ByID(orderTerm)} + case entity.FieldReferenceCount: + return []entity.OrderOption{entity.ByReferenceCount(orderTerm), entity.ByID(orderTerm)} + default: + return []entity.OrderOption{entity.ByID(orderTerm)} + } +} + +var fileCursorQuery = map[string]map[bool]func(token *PageToken) predicate.File{ + file.FieldName: { + true: func(token *PageToken) predicate.File { + return file.Or( + file.NameLT(token.String), + file.And(file.Name(token.String), file.IDLT(token.ID)), + ) + }, + false: func(token *PageToken) predicate.File { + return file.Or( + file.NameGT(token.String), + file.And(file.Name(token.String), file.IDGT(token.ID)), + ) + }, + }, + file.FieldSize: { + true: func(token *PageToken) predicate.File { + return file.Or( + file.SizeLT(int64(token.Int)), + file.And(file.Size(int64(token.Int)), file.IDLT(token.ID)), + ) + }, + false: func(token *PageToken) predicate.File { + return file.Or( + file.SizeGT(int64(token.Int)), + file.And(file.Size(int64(token.Int)), file.IDGT(token.ID)), + ) + }, + }, + file.FieldCreatedAt: { + true: func(token *PageToken) predicate.File { + 
return file.IDLT(token.ID) + }, + false: func(token *PageToken) predicate.File { + return file.IDGT(token.ID) + }, + }, + file.FieldUpdatedAt: { + true: func(token *PageToken) predicate.File { + return file.Or( + file.UpdatedAtLT(*token.Time), + file.And(file.UpdatedAt(*token.Time), file.IDLT(token.ID)), + ) + }, + false: func(token *PageToken) predicate.File { + return file.Or( + file.UpdatedAtGT(*token.Time), + file.And(file.UpdatedAt(*token.Time), file.IDGT(token.ID)), + ) + }, + }, + file.FieldID: { + true: func(token *PageToken) predicate.File { + return file.IDLT(token.ID) + }, + false: func(token *PageToken) predicate.File { + return file.IDGT(token.ID) + }, + }, +} + +func getFileCursorQuery(args *ListFileParameters, token *PageToken, query []*ent.FileQuery) []*ent.FileQuery { + o := &sql.OrderTermOptions{} + getOrderTerm(args.Order)(o) + + predicates, ok := fileCursorQuery[args.OrderBy] + if !ok { + predicates = fileCursorQuery[file.FieldID] + } + + // If all folder is already listed in previous page, only query for files. + if token != nil && token.StartWithFile && !args.MixedType { + query = query[1:2] + } + + // Mixing folders and files with one query + if args.MixedType { + query = query[2:] + } else if args.FolderOnly { + query = query[0:1] + } + + if token != nil { + query[0].Where(predicates[o.Desc](token)) + } + return query +} + +// getFileNextPageToken returns the next page token for the given last file. +func getFileNextPageToken(hasher hashid.Encoder, last *ent.File, args *ListFileParameters, nextStartWithFile bool) (string, error) { + token := &PageToken{ + ID: last.ID, + StartWithFile: nextStartWithFile, + } + + switch args.OrderBy { + case file.FieldName: + token.String = last.Name + case file.FieldSize: + token.Int = int(last.Size) + case file.FieldUpdatedAt: + token.Time = &last.UpdatedAt + } + + return token.Encode(hasher, hashid.EncodeFileID) +} diff --git a/inventory/group.go b/inventory/group.go new file mode 100644 index 00000000..c40b6e89 --- /dev/null +++ b/inventory/group.go @@ -0,0 +1,170 @@ +package inventory + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/group" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" +) + +type ( + // Ctx keys for eager loading options. + LoadGroupPolicy struct{} +) + +const ( + AnonymousGroupID = 3 +) + +type ( + GroupClient interface { + TxOperator + // AnonymousGroup returns the anonymous group. + AnonymousGroup(ctx context.Context) (*ent.Group, error) + // ListAll returns all groups. + ListAll(ctx context.Context) ([]*ent.Group, error) + // GetByID returns the group by id. + GetByID(ctx context.Context, id int) (*ent.Group, error) + // ListGroups returns a list of groups with pagination. + ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error) + // CountUsers returns the number of users in the group. + CountUsers(ctx context.Context, id int) (int, error) + // Upsert upserts a group. + Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error) + // Delete deletes a group. 
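Group lookups use the same context-key pattern as files: placing `LoadGroupPolicy{}` in the context makes the eager-loading helper at the end of this file attach the group's storage policies. A small sketch fetching the anonymous group (pinned to ID 3 above) with its policy edge loaded:

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// anonymousWithPolicy loads the anonymous group together with its
// storage-policy edge by opting into eager loading via the context.
func anonymousWithPolicy(ctx context.Context, gc inventory.GroupClient) (*ent.Group, error) {
	ctx = context.WithValue(ctx, inventory.LoadGroupPolicy{}, true)
	return gc.AnonymousGroup(ctx)
}
```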
+ Delete(ctx context.Context, id int) error + } + ListGroupParameters struct { + *PaginationArgs + } + ListGroupResult struct { + *PaginationResults + Groups []*ent.Group + } +) + +func NewGroupClient(client *ent.Client, dbType conf.DBType, cache cache.Driver) GroupClient { + return &groupClient{client: client, maxSQlParam: sqlParamLimit(dbType), cache: cache} +} + +type groupClient struct { + client *ent.Client + cache cache.Driver + maxSQlParam int +} + +func (c *groupClient) SetClient(newClient *ent.Client) TxOperator { + return &groupClient{client: newClient, maxSQlParam: c.maxSQlParam, cache: c.cache} +} + +func (c *groupClient) GetClient() *ent.Client { + return c.client +} + +func (c *groupClient) CountUsers(ctx context.Context, id int) (int, error) { + return c.client.Group.Query().Where(group.ID(id)).QueryUsers().Count(ctx) +} + +func (c *groupClient) AnonymousGroup(ctx context.Context) (*ent.Group, error) { + return withGroupEagerLoading(ctx, c.client.Group.Query().Where(group.ID(AnonymousGroupID))).First(ctx) +} + +func (c *groupClient) ListAll(ctx context.Context) ([]*ent.Group, error) { + return withGroupEagerLoading(ctx, c.client.Group.Query()).All(ctx) +} + +func (c *groupClient) Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error) { + + if group.ID == 0 { + return c.client.Group.Create(). + SetName(group.Name). + SetMaxStorage(group.MaxStorage). + SetSpeedLimit(group.SpeedLimit). + SetPermissions(group.Permissions). + SetSettings(group.Settings). + SetStoragePoliciesID(group.Edges.StoragePolicies.ID). + Save(ctx) + } + + res, err := c.client.Group.UpdateOne(group). + SetName(group.Name). + SetMaxStorage(group.MaxStorage). + SetSpeedLimit(group.SpeedLimit). + SetPermissions(group.Permissions). + SetSettings(group.Settings). + ClearStoragePolicies(). + SetStoragePoliciesID(group.Edges.StoragePolicies.ID). + Save(ctx) + if err != nil { + return nil, err + } + + return res, nil +} + +func (c *groupClient) Delete(ctx context.Context, id int) error { + if err := c.client.Group.DeleteOneID(id).Exec(ctx); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + + return nil +} + +func (c *groupClient) ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error) { + query := withGroupEagerLoading(ctx, c.client.Group.Query()) + pageSize := capPageSize(c.maxSQlParam, args.PageSize, 10) + queryWithoutOrder := query.Clone() + query.Order(getGroupOrderOption(args)...) + + // Count total items + total, err := queryWithoutOrder.Clone(). + Count(ctx) + if err != nil { + return nil, err + } + + groups, err := query.Clone(). + Limit(pageSize). + Offset(args.Page * pageSize). 
+ All(ctx) + if err != nil { + return nil, err + } + + return &ListGroupResult{ + Groups: groups, + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: pageSize, + }, + }, nil +} + +func (c *groupClient) GetByID(ctx context.Context, id int) (*ent.Group, error) { + return withGroupEagerLoading(ctx, c.client.Group.Query().Where(group.ID(id))).First(ctx) +} + +func getGroupOrderOption(args *ListGroupParameters) []group.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case group.FieldName: + return []group.OrderOption{group.ByName(orderTerm), group.ByID(orderTerm)} + case group.FieldMaxStorage: + return []group.OrderOption{group.ByMaxStorage(orderTerm), group.ByID(orderTerm)} + default: + return []group.OrderOption{group.ByID(orderTerm)} + } +} + +func withGroupEagerLoading(ctx context.Context, q *ent.GroupQuery) *ent.GroupQuery { + if _, ok := ctx.Value(LoadGroupPolicy{}).(bool); ok { + q.WithStoragePolicies(func(spq *ent.StoragePolicyQuery) { + withStoragePolicyEagerLoading(ctx, spq) + }) + } + return q +} diff --git a/inventory/node.go b/inventory/node.go new file mode 100644 index 00000000..1a005355 --- /dev/null +++ b/inventory/node.go @@ -0,0 +1,156 @@ +package inventory + +import ( + "context" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/node" +) + +type ( + LoadNodeStoragePolicy struct{} + NodeClient interface { + TxOperator + // ListActiveNodes returns the active nodes. + ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error) + // ListNodes returns the nodes with pagination. + ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error) + // GetNodeById returns the node by id. + GetNodeById(ctx context.Context, id int) (*ent.Node, error) + // GetNodeByIds returns the nodes by ids. + GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error) + // Upsert upserts a node. + Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error) + // Delete deletes a node. 
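The `Upsert` methods in this package (group above, node and storage policy below) share one convention: a zero `ID` creates a new row, a non-zero `ID` updates the existing one in place. A hedged sketch of leaning on that convention; field values are placeholders, and required settings plus validation are trimmed:

```go
package example // hypothetical caller package, not part of this change

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/ent"
	"github.com/cloudreve/Cloudreve/v4/ent/node"
	"github.com/cloudreve/Cloudreve/v4/inventory"
)

// createThenRename inserts a slave node, then updates it through the
// same Upsert call. Values here are illustrative only.
func createThenRename(ctx context.Context, nc inventory.NodeClient) error {
	created, err := nc.Upsert(ctx, &ent.Node{ // ID == 0 -> create
		Name:   "node-1",
		Server: "https://node-1.example.com",
		Status: node.StatusActive,
		Weight: 1,
	})
	if err != nil {
		return err
	}

	created.Name = "node-1-renamed"
	_, err = nc.Upsert(ctx, created) // non-zero ID -> update
	return err
}
```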
+ Delete(ctx context.Context, id int) error + } + ListNodeParameters struct { + *PaginationArgs + Status node.Status + } + ListNodeResult struct { + *PaginationResults + Nodes []*ent.Node + } +) + +func NewNodeClient(client *ent.Client) NodeClient { + return &nodeClient{ + client: client, + } +} + +type nodeClient struct { + client *ent.Client +} + +func (c *nodeClient) SetClient(newClient *ent.Client) TxOperator { + return &nodeClient{client: newClient} +} + +func (c *nodeClient) GetClient() *ent.Client { + return c.client +} + +func (c *nodeClient) ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error) { + stm := c.client.Node.Query().Where(node.StatusEQ(node.StatusActive)) + if len(subset) > 0 { + stm = stm.Where(node.IDIn(subset...)) + } + return stm.All(ctx) +} + +func (c *nodeClient) GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error) { + return withNodeEagerLoading(ctx, c.client.Node.Query().Where(node.IDIn(ids...))).All(ctx) +} + +func (c *nodeClient) GetNodeById(ctx context.Context, id int) (*ent.Node, error) { + return withNodeEagerLoading(ctx, c.client.Node.Query().Where(node.IDEQ(id))).First(ctx) +} + +func (c *nodeClient) Delete(ctx context.Context, id int) error { + return c.client.Node.DeleteOneID(id).Exec(ctx) +} + +func (c *nodeClient) ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error) { + query := c.client.Node.Query() + if string(args.Status) != "" { + query = query.Where(node.StatusEQ(args.Status)) + } + query.Order(getNodeOrderOption(args)...) + + // Count total items + total, err := query.Clone(). + Count(ctx) + if err != nil { + return nil, err + } + + nodes, err := withNodeEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, err + } + + return &ListNodeResult{ + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: args.PageSize, + }, + Nodes: nodes, + }, nil +} + +func (c *nodeClient) Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error) { + if n.ID == 0 { + return c.client.Node.Create(). + SetName(n.Name). + SetServer(n.Server). + SetSlaveKey(n.SlaveKey). + SetStatus(n.Status). + SetType(node.TypeSlave). + SetSettings(n.Settings). + SetCapabilities(n.Capabilities). + SetWeight(n.Weight). + Save(ctx) + } + + res, err := c.client.Node.UpdateOne(n). + SetName(n.Name). + SetServer(n.Server). + SetSlaveKey(n.SlaveKey). + SetStatus(n.Status). + SetSettings(n.Settings). + SetCapabilities(n.Capabilities). + SetWeight(n.Weight). 
+ Save(ctx) + if err != nil { + return nil, err + } + + return res, nil +} + +func getNodeOrderOption(args *ListNodeParameters) []node.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case node.FieldName: + return []node.OrderOption{node.ByName(orderTerm), node.ByID(orderTerm)} + case node.FieldWeight: + return []node.OrderOption{node.ByWeight(orderTerm), node.ByID(orderTerm)} + case node.FieldUpdatedAt: + return []node.OrderOption{node.ByUpdatedAt(orderTerm), node.ByID(orderTerm)} + default: + return []node.OrderOption{node.ByID(orderTerm)} + } +} + +func withNodeEagerLoading(ctx context.Context, query *ent.NodeQuery) *ent.NodeQuery { + if _, ok := ctx.Value(LoadNodeStoragePolicy{}).(bool); ok { + query = query.WithStoragePolicy(func(gq *ent.StoragePolicyQuery) { + withStoragePolicyEagerLoading(ctx, gq) + }) + } + + return query +} diff --git a/inventory/policy.go b/inventory/policy.go new file mode 100644 index 00000000..2343cf6e --- /dev/null +++ b/inventory/policy.go @@ -0,0 +1,243 @@ +package inventory + +import ( + "context" + "encoding/gob" + "fmt" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/storagepolicy" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" +) + +const ( + // StoragePolicyCacheKey is the cache key of storage policy. + StoragePolicyCacheKey = "storage_policy_" +) + +func init() { + gob.Register(ent.StoragePolicy{}) + gob.Register([]ent.StoragePolicy{}) +} + +type ( + LoadStoragePolicyGroup struct{} + SkipStoragePolicyCache struct{} + + StoragePolicyClient interface { + // GetByGroup returns the storage policies of the group. + GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error) + // GetPolicyByID returns the storage policy by id. + GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error) + // UpdateAccessKey updates the access key of the storage policy. It also clear related cache in KV. + UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error + // ListPolicyByType returns the storage policies by type. + ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error) + // ListPolicies returns the storage policies with pagination. + ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error) + // Upsert upserts the storage policy. + Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error) + // Delete deletes the storage policy. + Delete(ctx context.Context, policy *ent.StoragePolicy) error + } + + ListPolicyParameters struct { + *PaginationArgs + Type types.PolicyType + } + + ListPolicyResult struct { + *PaginationResults + Policies []*ent.StoragePolicy + } +) + +// NewStoragePolicyClient returns a new StoragePolicyClient. 
+func NewStoragePolicyClient(client *ent.Client, cache cache.Driver) StoragePolicyClient { + return &storagePolicyClient{client: client, cache: cache} +} + +type storagePolicyClient struct { + client *ent.Client + cache cache.Driver +} + +func (c *storagePolicyClient) Delete(ctx context.Context, policy *ent.StoragePolicy) error { + if err := c.client.StoragePolicy.DeleteOne(policy).Exec(ctx); err != nil { + return fmt.Errorf("failed to delete storage policy: %w", err) + } + + // Clear cache + if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil { + return fmt.Errorf("failed to clear storage policy cache: %w", err) + } + return nil +} + +func (c *storagePolicyClient) Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error) { + var nodeId *int + if policy.NodeID != 0 { + nodeId = &policy.NodeID + } + if policy.ID == 0 { + p, err := c.client.StoragePolicy. + Create(). + SetName(policy.Name). + SetType(policy.Type). + SetServer(policy.Server). + SetBucketName(policy.BucketName). + SetIsPrivate(policy.IsPrivate). + SetAccessKey(policy.AccessKey). + SetSecretKey(policy.SecretKey). + SetMaxSize(policy.MaxSize). + SetDirNameRule(policy.DirNameRule). + SetFileNameRule(policy.FileNameRule). + SetSettings(policy.Settings). + SetNillableNodeID(nodeId). + Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create storage policy: %w", err) + } + return p, nil + } + + updateQuery := c.client.StoragePolicy.UpdateOne(policy). + SetName(policy.Name). + SetType(policy.Type). + SetServer(policy.Server). + SetBucketName(policy.BucketName). + SetIsPrivate(policy.IsPrivate). + SetSecretKey(policy.SecretKey). + SetMaxSize(policy.MaxSize). + SetDirNameRule(policy.DirNameRule). + SetFileNameRule(policy.FileNameRule). + SetSettings(policy.Settings). + SetNillableNodeID(nodeId) + + if policy.Type != types.PolicyTypeOd { + updateQuery.SetAccessKey(policy.AccessKey) + } + + p, err := updateQuery.Save(ctx) + + // Clear cache + if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil { + return nil, fmt.Errorf("failed to clear storage policy cache: %w", err) + } + if err != nil { + return nil, fmt.Errorf("failed to update storage policy: %w", err) + } + return p, nil + +} + +func (c *storagePolicyClient) GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error) { + val, skipCache := ctx.Value(SkipStoragePolicyCache{}).(bool) + skipCache = skipCache && val + + res, err := withStoragePolicyEagerLoading(ctx, c.client.Group.QueryStoragePolicies(group)).WithNode().First(ctx) + if err != nil { + return nil, fmt.Errorf("get storage policies: %w", err) + } + + return res, nil +} + +// GetPolicyByID returns the storage policy by id. 
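The GetPolicyByID implementation that follows reads through the KV cache under StoragePolicyCacheKey and only queries the database on a miss; a caller that needs a guaranteed-fresh row can opt out with the SkipStoragePolicyCache context key. A hedged sketch of that pattern (context, ent, and this package are assumed to be imported):

```go
// freshPolicy bypasses the per-policy KV cache entry and forces a
// database read; the cache entry is left untouched for later readers.
func freshPolicy(ctx context.Context, policies inventory.StoragePolicyClient, id int) (*ent.StoragePolicy, error) {
	ctx = context.WithValue(ctx, inventory.SkipStoragePolicyCache{}, true)
	return policies.GetPolicyByID(ctx, id)
}
```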
+func (c *storagePolicyClient) GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error) {
+	val, skipCache := ctx.Value(SkipStoragePolicyCache{}).(bool)
+	skipCache = skipCache && val
+
+	// Try to read from cache
+	if c.cache != nil && !skipCache {
+		if res, ok := c.cache.Get(StoragePolicyCacheKey + strconv.Itoa(id)); ok {
+			cached := res.(ent.StoragePolicy)
+			return &cached, nil
+		}
+	}
+
+	res, err := withStoragePolicyEagerLoading(ctx, c.client.StoragePolicy.Query().Where(storagepolicy.ID(id))).WithNode().First(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("get storage policy: %w", err)
+	}
+
+	// Write to cache
+	if c.cache != nil && !skipCache {
+		_ = c.cache.Set(StoragePolicyCacheKey+strconv.Itoa(id), *res, -1)
+	}
+
+	return res, nil
+}
+
+func (c *storagePolicyClient) ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error) {
+	policies, err := c.client.StoragePolicy.Query().Where(storagepolicy.TypeEQ(string(t))).All(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list storage policies: %w", err)
+	}
+
+	return policies, nil
+}
+
+func (c *storagePolicyClient) UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error {
+	_, err := c.client.StoragePolicy.UpdateOne(policy).SetAccessKey(token).Save(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to update access key in DB: %w", err)
+	}
+
+	// Clear cache
+	if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
+		return fmt.Errorf("failed to clear storage policy cache: %w", err)
+	}
+
+	return nil
+}
+
+func (c *storagePolicyClient) ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error) {
+	query := c.client.StoragePolicy.Query().WithNode()
+	if args.Type != "" {
+		query = query.Where(storagepolicy.TypeEQ(string(args.Type)))
+	}
+	query.Order(getStoragePolicyOrderOption(args)...)
+
+	// Count total items
+	total, err := query.Clone().
+ Count(ctx) + if err != nil { + return nil, err + } + + policies, err := withStoragePolicyEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, err + } + + return &ListPolicyResult{ + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: args.PageSize, + }, + Policies: policies, + }, nil +} + +func getStoragePolicyOrderOption(args *ListPolicyParameters) []storagepolicy.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case storagepolicy.FieldUpdatedAt: + return []storagepolicy.OrderOption{storagepolicy.ByUpdatedAt(orderTerm), storagepolicy.ByID(orderTerm)} + default: + return []storagepolicy.OrderOption{storagepolicy.ByID(orderTerm)} + } +} + +func withStoragePolicyEagerLoading(ctx context.Context, query *ent.StoragePolicyQuery) *ent.StoragePolicyQuery { + if _, ok := ctx.Value(LoadStoragePolicyGroup{}).(bool); ok { + query = query.WithGroups(func(gq *ent.GroupQuery) { + withGroupEagerLoading(ctx, gq) + }) + } + return query +} diff --git a/inventory/setting.go b/inventory/setting.go new file mode 100644 index 00000000..a63f7a91 --- /dev/null +++ b/inventory/setting.go @@ -0,0 +1,248 @@ +package inventory + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" +) + +type ( + SettingClient interface { + TxOperator + // Get gets a setting value from DB, returns error if setting cannot be found. + Get(ctx context.Context, name string) (string, error) + // Set sets a setting value to DB. + Set(ctx context.Context, settings map[string]string) error + // Gets gets multiple setting values from DB, returns error if any setting cannot be found. 
+ Gets(ctx context.Context, names []string) (map[string]string, error) + } +) + +// NewSettingClient creates a new SettingClient +func NewSettingClient(client *ent.Client, kv cache.Driver) SettingClient { + return &settingClient{client: client, kv: kv} +} + +type settingClient struct { + client *ent.Client + kv cache.Driver +} + +// SetClient sets the client for the setting client +func (c *settingClient) SetClient(newClient *ent.Client) TxOperator { + return &settingClient{client: newClient, kv: c.kv} +} + +// GetClient gets the client for the setting client +func (c *settingClient) GetClient() *ent.Client { + return c.client +} + +func (c *settingClient) Get(ctx context.Context, name string) (string, error) { + s, err := c.client.Setting.Query().Where(setting.Name(name)).Only(ctx) + if err != nil { + return "", fmt.Errorf("failed to query setting %q from DB: %w", name, err) + } + + return s.Value, nil +} + +func (c *settingClient) Gets(ctx context.Context, names []string) (map[string]string, error) { + settings := make(map[string]string) + res, err := c.client.Setting.Query().Where(setting.NameIn(names...)).All(ctx) + if err != nil { + return nil, err + } + + for _, s := range res { + settings[s.Name] = s.Value + } + + return settings, nil +} + +func (c *settingClient) Set(ctx context.Context, settings map[string]string) error { + for k, v := range settings { + if err := c.client.Setting.Update().Where(setting.Name(k)).SetValue(v).Exec(ctx); err != nil { + return fmt.Errorf("failed to create setting %q: %w", k, err) + } + + } + + return nil +} + +var DefaultSettings = map[string]string{ + "siteURL": `http://localhost:5212`, + "siteName": `Cloudreve`, + "siteDes": "Cloudreve", + "siteID": uuid.Must(uuid.NewV4()).String(), + "siteTitle": "Cloud storage for everyone", + "siteScript": "", + "pwa_small_icon": "/static/img/favicon.ico", + "pwa_medium_icon": "/static/img/logo192.png", + "pwa_large_icon": "/static/img/logo512.png", + "pwa_display": "standalone", + "pwa_theme_color": "#000000", + "pwa_background_color": "#ffffff", + "register_enabled": `1`, + "default_group": `2`, + "fromName": `Cloudreve`, + "mail_keepalive": `30`, + "fromAdress": `no-reply@cloudreve.org`, + "smtpHost": `smtp.cloudreve.com`, + "smtpPort": `25`, + "replyTo": `support@cloudreve.org`, + "smtpUser": `smtp.cloudreve.com`, + "smtpPass": ``, + "smtpEncryption": `0`, + "ban_time": `604800`, + "maxEditSize": `52428800`, + "archive_timeout": `600`, + "upload_session_timeout": `86400`, + "slave_api_timeout": `60`, + "folder_props_timeout": `300`, + "chunk_retries": `5`, + "use_temp_chunk_buffer": `1`, + "login_captcha": `0`, + "reg_captcha": `0`, + "email_active": `0`, + "forget_captcha": `0`, + "gravatar_server": `https://www.gravatar.com/`, + "defaultTheme": `#1976d2`, + "theme_options": `{"#1976d2":{"light":{"palette":{"primary":{"main":"#1976d2","light":"#42a5f5","dark":"#1565c0"},"secondary":{"main":"#9c27b0","light":"#ba68c8","dark":"#7b1fa2"}}},"dark":{"palette":{"primary":{"main":"#90caf9","light":"#e3f2fd","dark":"#42a5f5"},"secondary":{"main":"#ce93d8","light":"#f3e5f5","dark":"#ab47bc"}}}},"#3f51b5":{"light":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"dark":{"palette":{"primary":{"main":"#9fa8da"},"secondary":{"main":"#ff4081"}}}}}`, + "max_parallel_transfer": `4`, + "secret_key": util.RandStringRunes(256), + "temp_path": "temp", + "avatar_path": "avatar", + "avatar_size": "4194304", + "avatar_size_l": "200", + "cron_garbage_collect": "@every 30m", + "cron_entity_collect": "@every 
15m", + "cron_trash_bin_collect": "@every 33m", + "cron_oauth_cred_refresh": "@every 230h", + "authn_enabled": "1", + "captcha_type": "normal", + "captcha_height": "60", + "captcha_width": "240", + "captcha_mode": "3", + "captcha_ComplexOfNoiseText": "0", + "captcha_ComplexOfNoiseDot": "0", + "captcha_IsShowHollowLine": "0", + "captcha_IsShowNoiseDot": "1", + "captcha_IsShowNoiseText": "0", + "captcha_IsShowSlimeLine": "1", + "captcha_IsShowSineLine": "0", + "captcha_CaptchaLen": "6", + "captcha_ReCaptchaKey": "defaultKey", + "captcha_ReCaptchaSecret": "defaultSecret", + "captcha_turnstile_site_key": "", + "captcha_turnstile_site_secret": "", + "thumb_width": "400", + "thumb_height": "300", + "thumb_entity_suffix": "._thumb", + "thumb_slave_sidecar_suffix": "._thumb_sidecar", + "thumb_encode_method": "png", + "thumb_gc_after_gen": "0", + "thumb_encode_quality": "95", + "thumb_builtin_enabled": "1", + "thumb_builtin_max_size": "78643200", // 75 MB + "thumb_vips_max_size": "78643200", // 75 MB + "thumb_vips_enabled": "0", + "thumb_vips_exts": "3fr,ari,arw,bay,braw,crw,cr2,cr3,cap,data,dcs,dcr,dng,drf,eip,erf,fff,gpr,iiq,k25,kdc,mdc,mef,mos,mrw,nef,nrw,obm,orf,pef,ptx,pxn,r3d,raf,raw,rwl,rw2,rwz,sr2,srf,srw,tif,x3f,csv,mat,img,hdr,pbm,pgm,ppm,pfm,pnm,svg,svgz,j2k,jp2,jpt,j2c,jpc,gif,png,jpg,jpeg,jpe,webp,tif,tiff,fits,fit,fts,exr,jxl,pdf,heic,heif,avif,svs,vms,vmu,ndpi,scn,mrxs,svslide,bif,raw", + "thumb_ffmpeg_enabled": "0", + "thumb_vips_path": "vips", + "thumb_ffmpeg_path": "ffmpeg", + "thumb_ffmpeg_max_size": "10737418240", // 10 GB + "thumb_ffmpeg_exts": "3g2,3gp,asf,asx,avi,divx,flv,m2ts,m2v,m4v,mkv,mov,mp4,mpeg,mpg,mts,mxf,ogv,rm,swf,webm,wmv", + "thumb_ffmpeg_seek": "00:00:01.00", + "thumb_libreoffice_path": "soffice", + "thumb_libreoffice_max_size": "78643200", // 75 MB + "thumb_libreoffice_enabled": "0", + "thumb_libreoffice_exts": "txt,pdf,md,ods,ots,fods,uos,xlsx,xml,xls,xlt,dif,dbf,html,slk,csv,xlsm,docx,dotx,doc,dot,rtf,xlsm,xlst,xls,xlw,xlc,xlt,pptx,ppsx,potx,pomx,ppt,pps,ppm,pot,pom", + "thumb_music_cover_enabled": "1", + "thumb_music_cover_exts": "mp3,m4a,ogg,flac", + "thumb_music_cover_max_size": "1073741824", // 1 GB + "phone_required": "false", + "phone_enabled": "false", + "show_app_promotion": "1", + "public_resource_maxage": "86400", + "viewer_session_timeout": "36000", + "hash_id_salt": util.RandStringRunes(64), + "mail_activation_template": `[{"language":"en-US","title":"Activate your account","body":"
                                                           
"},{"language":"zh-CN","title":"激活你的账号","body":"
                                                           
"}]`, + "mail_reset_template": `[{"language":"en-US","title":"Reset your password","body":"
                                                           
"},{"language":"zh-CN","title":"重设密码","body":"
                                                           
"}]`, + "access_token_ttl": "3600", + "refresh_token_ttl": "1209600", // 2 weeks + "use_cursor_pagination": "1", + "max_page_size": "2000", + "max_recursive_searched_folder": "65535", + "max_batched_file": "3000", + "queue_media_meta_worker_num": "30", + "queue_media_meta_max_execution": "600", + "queue_media_meta_backoff_factor": "2", + "queue_media_meta_backoff_max_duration": "60", + "queue_media_meta_max_retry": "1", + "queue_media_meta_retry_delay": "0", + "queue_thumb_worker_num": "15", + "queue_thumb_max_execution": "300", + "queue_thumb_backoff_factor": "2", + "queue_thumb_backoff_max_duration": "60", + "queue_thumb_max_retry": "0", + "queue_thumb_retry_delay": "0", + "queue_recycle_worker_num": "5", + "queue_recycle_max_execution": "900", + "queue_recycle_backoff_factor": "2", + "queue_recycle_backoff_max_duration": "60", + "queue_recycle_max_retry": "0", + "queue_recycle_retry_delay": "0", + "queue_io_intense_worker_num": "30", + "queue_io_intense_max_execution": "2592000", + "queue_io_intense_backoff_factor": "2", + "queue_io_intense_backoff_max_duration": "600", + "queue_io_intense_max_retry": "5", + "queue_io_intense_retry_delay": "0", + "queue_remote_download_worker_num": "5", + "queue_remote_download_max_execution": "864000", + "queue_remote_download_backoff_factor": "2", + "queue_remote_download_backoff_max_duration": "600", + "queue_remote_download_max_retry": "5", + "queue_remote_download_retry_delay": "0", + "entity_url_default_ttl": "3600", + "entity_url_cache_margin": "600", + "media_meta": "1", + "media_meta_exif": "1", + "media_meta_exif_size_local": "1073741824", + "media_meta_exif_size_remote": "104857600", + "media_meta_exif_brute_force": "1", + "media_meta_music": "1", + "media_meta_music_size_local": "1073741824", + "media_exif_music_size_remote": "1073741824", + "media_meta_ffprobe": "0", + "media_meta_ffprobe_path": "ffprobe", + "media_meta_ffprobe_size_local": "0", + "media_meta_ffprobe_size_remote": "0", + "site_logo": "/static/img/logo.svg", + "site_logo_light": "/static/img/logo_light.svg", + "tos_url": "https://cloudreve.org/privacy-policy", + "privacy_policy_url": "https://cloudreve.org/privacy-policy", + "explorer_icons": 
`[{"exts":["mp3","flac","ape","wav","acc","ogg","m4a"],"icon":"audio","color":"#651fff"},{"exts":["mp4","flv","avi","wmv","mkv","rm","rmvb","mov","ogv"],"icon":"video","color":"#d50000"},{"exts":["bmp","iff","png","gif","jpg","jpeg","psd","svg","webp","heif","heic","tiff","avif"],"icon":"image","color":"#d32f2f"},{"exts":["3fr","ari","arw","bay","braw","crw","cr2","cr3","cap","dcs","dcr","dng","drf","eip","erf","fff","gpr","iiq","k25","kdc","mdc","mef","mos","mrw","nef","nrw","obm","orf","pef","ptx","pxn","r3d","raf","raw","rwl","rw2","rwz","sr2","srf","srw","tif","x3f"],"icon":"raw","color":"#d32f2f"},{"exts":["pdf"],"color":"#f44336","icon":"pdf"},{"exts":["doc","docx"],"color":"#538ce5","icon":"word"},{"exts":["ppt","pptx"],"color":"#EF633F","icon":"ppt"},{"exts":["xls","xlsx","csv"],"color":"#4caf50","icon":"excel"},{"exts":["txt","html"],"color":"#607d8b","icon":"text"},{"exts":["torrent"],"color":"#5c6bc0","icon":"torrent"},{"exts":["zip","gz","xz","tar","rar","7z","bz2","z"],"color":"#f9a825","icon":"zip"},{"exts":["exe","msi"],"color":"#1a237e","icon":"exe"},{"exts":["apk"],"color":"#8bc34a","icon":"android"},{"exts":["go"],"color":"#16b3da","icon":"go"},{"exts":["py"],"color":"#3776ab","icon":"python"},{"exts":["c"],"color":"#a4c639","icon":"c"},{"exts":["cpp"],"color":"#f34b7d","icon":"cpp"},{"exts":["js","jsx"],"color":"#f4d003","icon":"js"},{"exts":["epub"],"color":"#81b315","icon":"book"},{"exts":["rs"],"color":"#000","color_dark":"#fff","icon":"rust"},{"exts":["drawio"],"color":"#F08705","icon":"flowchart"},{"exts":["dwb"],"color":"#F08705","icon":"whiteboard"},{"exts":["md"],"color":"#383838","color_dark":"#cbcbcb","icon":"markdown"}]`, + "explorer_category_image_query": "type=file&case_folding&use_or&name=*.bmp&name=*.iff&name=*.png&name=*.gif&name=*.jpg&name=*.jpeg&name=*.psd&name=*.svg&name=*.webp&name=*.heif&name=*.heic&name=*.tiff&name=*.avif&name=*.3fr&name=*.ari&name=*.arw&name=*.bay&name=*.braw&name=*.crw&name=*.cr2&name=*.cr3&name=*.cap&name=*.dcs&name=*.dcr&name=*.dng&name=*.drf&name=*.eip&name=*.erf&name=*.fff&name=*.gpr&name=*.iiq&name=*.k25&name=*.kdc&name=*.mdc&name=*.mef&name=*.mos&name=*.mrw&name=*.nef&name=*.nrw&name=*.obm&name=*.orf&name=*.pef&name=*.ptx&name=*.pxn&name=*.r3d&name=*.raf&name=*.raw&name=*.rwl&name=*.rw2&name=*.rwz&name=*.sr2&name=*.srf&name=*.srw&name=*.tif&name=*.x3f", + "explorer_category_video_query": "type=file&case_folding&use_or&name=*.mp4&name=*.flv&name=*.avi&name=*.wmv&name=*.mkv&name=*.rm&name=*.rmvb&name=*.mov&name=*.ogv", + "explorer_category_audio_query": "type=file&case_folding&use_or&name=*.mp3&name=*.flac&name=*.ape&name=*.wav&name=*.acc&name=*.ogg&name=*.m4a", + "explorer_category_document_query": "type=file&case_folding&use_or&name=*.pdf&name=*.doc&name=*.docx&name=*.ppt&name=*.pptx&name=*.xls&name=*.xlsx&name=*.csv&name=*.txt&name=*.md&name=*.pub", + "use_sse_for_search": "0", + "emojis": 
`{"😀":["😀","😃","😄","😁","😆","😅","🤣","😂","🙂","🙃","🫠","😉","😊","😇","🥰","😍","🤩","😘","😗","😚","😙","🥲","😋","😛","😜","🤪","😝","🤑","🤗","🤭","🫢","🫣","🤫","🤔","🫡","🤐","🤨","😐","😑","😶","😶‍🌫️","😏","😒","🙄","😬","😮‍💨","🤥","😌","😔","😪","🤤","😴","😷","🤒","🤕","🤢","🤮","🤧","🥵","🥶","🥴","😵","😵‍💫","🤯","🤠","🥳","🥸","😎","🤓","🧐","😕","🫤","😟","🙁","😮","😯","😲","😳","🥺","🥹","😦","😧","😨","😰","😥","😢","😭","😱","😖","😣","😞","😓","😩","😫","🥱","😤","😡","😠","🤬","😈","👿","💀","☠️","💩","🤡","👹","👺","👻","👽","👾","🤖","😺","😸","😹","😻","😼","😽","🙀","😿","😾","🙈","🙉","🙊","💋","💌","💘","💝","💖","💗","💓","💞","💕","💟","💔","❤️‍🔥","❤️‍🩹","❤️","🧡","💛","💚","💙","💜","🤎","🖤","🤍","💯","💢","💥","💫","💦","💨","🕳️","💣","💬","👁️‍🗨️","🗨️","🗯️","💭","💤"],"👋":["👋","🤚","🖐️","✋","🖖","🫱","🫲","🫳","🫴","👌","🤌","🤏","✌️","🤞","🫰","🤟","🤘","🤙","👈","👉","👆","🖕","👇","☝️","🫵","👍","👎","✊","👊","🤛","🤜","👏","🙌","🫶","👐","🤲","🤝","🙏","✍️","💅","🤳","💪","🦾","🦿","🦵","🦶","👂","🦻","👃","🧠","🫀","🫁","🦷","🦴","👀","👁️","👅","👄","🫦","👶","🧒","👦","👧","🧑","👱","👨","🧔","🧔‍♂️","🧔‍♀️","👨‍🦰","👨‍🦱","👨‍🦳","👨‍🦲","👩","👩‍🦰","🧑‍🦰","👩‍🦱","🧑‍🦱","👩‍🦳","🧑‍🦳","👩‍🦲","🧑‍🦲","👱‍♀️","👱‍♂️","🧓","👴","👵","🙍","🙍‍♂️","🙍‍♀️","🙎","🙎‍♂️","🙎‍♀️","🙅","🙅‍♂️","🙅‍♀️","🙆","🙆‍♂️","🙆‍♀️","💁","💁‍♂️","💁‍♀️","🙋","🙋‍♂️","🙋‍♀️","🧏","🧏‍♂️","🧏‍♀️","🙇","🙇‍♂️","🙇‍♀️","🤦","🤦‍♂️","🤦‍♀️","🤷","🤷‍♂️","🤷‍♀️","🧑‍⚕️","👨‍⚕️","👩‍⚕️","🧑‍🎓","👨‍🎓","👩‍🎓","🧑‍🏫","👨‍🏫","👩‍🏫","🧑‍⚖️","👨‍⚖️","👩‍⚖️","🧑‍🌾","👨‍🌾","👩‍🌾","🧑‍🍳","👨‍🍳","👩‍🍳","🧑‍🔧","👨‍🔧","👩‍🔧","🧑‍🏭","👨‍🏭","👩‍🏭","🧑‍💼","👨‍💼","👩‍💼","🧑‍🔬","👨‍🔬","👩‍🔬","🧑‍💻","👨‍💻","👩‍💻","🧑‍🎤","👨‍🎤","👩‍🎤","🧑‍🎨","👨‍🎨","👩‍🎨","🧑‍✈️","👨‍✈️","👩‍✈️","🧑‍🚀","👨‍🚀","👩‍🚀","🧑‍🚒","👨‍🚒","👩‍🚒","👮","👮‍♂️","👮‍♀️","🕵️","🕵️‍♂️","🕵️‍♀️","💂","💂‍♂️","💂‍♀️","🥷","👷","👷‍♂️","👷‍♀️","🫅","🤴","👸","👳","👳‍♂️","👳‍♀️","👲","🧕","🤵","🤵‍♂️","🤵‍♀️","👰","👰‍♂️","👰‍♀️","🤰","🫃","🫄","🤱","👩‍🍼","👨‍🍼","🧑‍🍼","👼","🎅","🤶","🧑‍🎄","🦸","🦸‍♂️","🦸‍♀️","🦹","🦹‍♂️","🦹‍♀️","🧙","🧙‍♂️","🧙‍♀️","🧚","🧚‍♂️","🧚‍♀️","🧛","🧛‍♂️","🧛‍♀️","🧜","🧜‍♂️","🧜‍♀️","🧝","🧝‍♂️","🧝‍♀️","🧞","🧞‍♂️","🧞‍♀️","🧟","🧟‍♂️","🧟‍♀️","🧌","💆","💆‍♂️","💆‍♀️","💇","💇‍♂️","💇‍♀️","🚶","🚶‍♂️","🚶‍♀️","🧍","🧍‍♂️","🧍‍♀️","🧎","🧎‍♂️","🧎‍♀️","🧑‍🦯","👨‍🦯","👩‍🦯","🧑‍🦼","👨‍🦼","👩‍🦼","🧑‍🦽","👨‍🦽","👩‍🦽","🏃","🏃‍♂️","🏃‍♀️","💃","🕺","🕴️","👯","👯‍♂️","👯‍♀️","🧖","🧖‍♂️","🧖‍♀️","🧗","🧗‍♂️","🧗‍♀️","🤺","🏇","⛷️","🏂","🏌️","🏌️‍♂️","🏌️‍♀️","🏄","🏄‍♂️","🏄‍♀️","🚣","🚣‍♂️","🚣‍♀️","🏊","🏊‍♂️","🏊‍♀️","⛹️","⛹️‍♂️","⛹️‍♀️","🏋️","🏋️‍♂️","🏋️‍♀️","🚴","🚴‍♂️","🚴‍♀️","🚵","🚵‍♂️","🚵‍♀️","🤸","🤸‍♂️","🤸‍♀️","🤼","🤼‍♂️","🤼‍♀️","🤽","🤽‍♂️","🤽‍♀️","🤾","🤾‍♂️","🤾‍♀️","🤹","🤹‍♂️","🤹‍♀️","🧘","🧘‍♂️","🧘‍♀️","🛀","🛌","🧑‍🤝‍🧑","👭","👫","👬","💏","👩‍❤️‍💋‍👨","👨‍❤️‍💋‍👨","👩‍❤️‍💋‍👩","💑","👩‍❤️‍👨","👨‍❤️‍👨","👩‍❤️‍👩","👪","👨‍👩‍👦","👨‍👩‍👧","👨‍👩‍👧‍👦","👨‍👩‍👦‍👦","👨‍👩‍👧‍👧","👨‍👨‍👦","👨‍👨‍👧","👨‍👨‍👧‍👦","👨‍👨‍👦‍👦","👨‍👨‍👧‍👧","👩‍👩‍👦","👩‍👩‍👧","👩‍👩‍👧‍👦","👩‍👩‍👦‍👦","👩‍👩‍👧‍👧","👨‍👦","👨‍👦‍👦","👨‍👧","👨‍👧‍👦","👨‍👧‍👧","👩‍👦","👩‍👦‍👦","👩‍👧","👩‍👧‍👦","👩‍👧‍👧","🗣️","👤","👥","🫂","👣","🦰","🦱","🦳","🦲"],"🐵":["🐵","🐒","🦍","🦧","🐶","🐕","🦮","🐕‍🦺","🐩","🐺","🦊","🦝","🐱","🐈","🐈‍⬛","🦁","🐯","🐅","🐆","🐴","🐎","🦄","🦓","🦌","🦬","🐮","🐂","🐃","🐄","🐷","🐖","🐗","🐽","🐏","🐑","🐐","🐪","🐫","🦙","🦒","🐘","🦣","🦏","🦛","🐭","🐁","🐀","🐹","🐰","🐇","🐿️","🦫","🦔","🦇","🐻","🐻‍❄️","🐨","🐼","🦥","🦦","🦨","🦘","🦡","🐾","🦃","🐔","🐓","🐣","🐤","🐥","🐦","🐧","🕊️","🦅","🦆","🦢","🦉","🦤","🪶","🦩","🦚","🦜","🐸","🐊","🐢","🦎","🐍","🐲","🐉","🦕","🦖","🐳","🐋","🐬","🦭","🐟","🐠","🐡","🦈","🐙","🐚","🪸","🐌","🦋","🐛","🐜","🐝","🪲","🐞","🦗","🪳","🕷️","🕸️","🦂","🦟","🪰","🪱","🦠","💐","🌸","💮","🪷","🏵️","🌹","🥀","🌺","🌻","🌼","🌷","🌱","🪴","🌲","🌳","🌴","🌵","🌾","🌿","☘️","🍀","🍁","🍂","🍃","🪹","🪺"],"🍇":["🍇","🍈","🍉","🍊","🍋","🍌","🍍","🥭","🍎","🍏","🍐","🍑","🍒","🍓","🫐","🥝","🍅","🫒","🥥","🥑","🍆","🥔","🥕","🌽","🌶️","🫑","🥒","🥬","🥦","🧄","🧅","🍄","🥜","🫘","🌰","🍞","🥐","🥖","🫓","🥨","🥯","🥞","🧇","🧀","🍖","🍗","🥩","🥓"
,"🍔","🍟","🍕","🌭","🥪","🌮","🌯","🫔","🥙","🧆","🥚","🍳","🥘","🍲","🫕","🥣","🥗","🍿","🧈","🧂","🥫","🍱","🍘","🍙","🍚","🍛","🍜","🍝","🍠","🍢","🍣","🍤","🍥","🥮","🍡","🥟","🥠","🥡","🦀","🦞","🦐","🦑","🦪","🍦","🍧","🍨","🍩","🍪","🎂","🍰","🧁","🥧","🍫","🍬","🍭","🍮","🍯","🍼","🥛","☕","🫖","🍵","🍶","🍾","🍷","🍸","🍹","🍺","🍻","🥂","🥃","🫗","🥤","🧋","🧃","🧉","🧊","🥢","🍽️","🍴","🥄","🔪","🫙","🏺"],"🌍":["🌍","🌎","🌏","🌐","🗺️","🗾","🧭","🏔️","⛰️","🌋","🗻","🏕️","🏖️","🏜️","🏝️","🏞️","🏟️","🏛️","🏗️","🧱","🪨","🪵","🛖","🏘️","🏚️","🏠","🏡","🏢","🏣","🏤","🏥","🏦","🏨","🏩","🏪","🏫","🏬","🏭","🏯","🏰","💒","🗼","🗽","⛪","🕌","🛕","🕍","⛩️","🕋","⛲","⛺","🌁","🌃","🏙️","🌄","🌅","🌆","🌇","🌉","♨️","🎠","🛝","🎡","🎢","💈","🎪","🚂","🚃","🚄","🚅","🚆","🚇","🚈","🚉","🚊","🚝","🚞","🚋","🚌","🚍","🚎","🚐","🚑","🚒","🚓","🚔","🚕","🚖","🚗","🚘","🚙","🛻","🚚","🚛","🚜","🏎️","🏍️","🛵","🦽","🦼","🛺","🚲","🛴","🛹","🛼","🚏","🛣️","🛤️","🛢️","⛽","🛞","🚨","🚥","🚦","🛑","🚧","⚓","🛟","⛵","🛶","🚤","🛳️","⛴️","🛥️","🚢","✈️","🛩️","🛫","🛬","🪂","💺","🚁","🚟","🚠","🚡","🛰️","🚀","🛸","🛎️","🧳","⌛","⏳","⌚","⏰","⏱️","⏲️","🕰️","🕛","🕧","🕐","🕜","🕑","🕝","🕒","🕞","🕓","🕟","🕔","🕠","🕕","🕡","🕖","🕢","🕗","🕣","🕘","🕤","🕙","🕥","🕚","🕦","🌑","🌒","🌓","🌔","🌕","🌖","🌗","🌘","🌙","🌚","🌛","🌜","🌡️","☀️","🌝","🌞","🪐","⭐","🌟","🌠","🌌","☁️","⛅","⛈️","🌤️","🌥️","🌦️","🌧️","🌨️","🌩️","🌪️","🌫️","🌬️","🌀","🌈","🌂","☂️","☔","⛱️","⚡","❄️","☃️","⛄","☄️","🔥","💧","🌊"],"🎃":["🎃","🎄","🎆","🎇","🧨","✨","🎈","🎉","🎊","🎋","🎍","🎎","🎏","🎐","🎑","🧧","🎀","🎁","🎗️","🎟️","🎫","🎖️","🏆","🏅","🥇","🥈","🥉","⚽","⚾","🥎","🏀","🏐","🏈","🏉","🎾","🥏","🎳","🏏","🏑","🏒","🥍","🏓","🏸","🥊","🥋","🥅","⛳","⛸️","🎣","🤿","🎽","🎿","🛷","🥌","🎯","🪀","🪁","🎱","🔮","🪄","🧿","🪬","🎮","🕹️","🎰","🎲","🧩","🧸","🪅","🪩","🪆","♠️","♥️","♦️","♣️","♟️","🃏","🀄","🎴","🎭","🖼️","🎨","🧵","🪡","🧶","🪢"],"👓":["👓","🕶️","🥽","🥼","🦺","👔","👕","👖","🧣","🧤","🧥","🧦","👗","👘","🥻","🩱","🩲","🩳","👙","👚","👛","👜","👝","🛍️","🎒","🩴","👞","👟","🥾","🥿","👠","👡","🩰","👢","👑","👒","🎩","🎓","🧢","🪖","⛑️","📿","💄","💍","💎","🔇","🔈","🔉","🔊","📢","📣","📯","🔔","🔕","🎼","🎵","🎶","🎙️","🎚️","🎛️","🎤","🎧","📻","🎷","🪗","🎸","🎹","🎺","🎻","🪕","🥁","🪘","📱","📲","☎️","📞","📟","📠","🔋","🪫","🔌","💻","🖥️","🖨️","⌨️","🖱️","🖲️","💽","💾","💿","📀","🧮","🎥","🎞️","📽️","🎬","📺","📷","📸","📹","📼","🔍","🔎","🕯️","💡","🔦","🏮","🪔","📔","📕","📖","📗","📘","📙","📚","📓","📒","📃","📜","📄","📰","🗞️","📑","🔖","🏷️","💰","🪙","💴","💵","💶","💷","💸","💳","🧾","💹","✉️","📧","📨","📩","📤","📥","📦","📫","📪","📬","📭","📮","🗳️","✏️","✒️","🖋️","🖊️","🖌️","🖍️","📝","💼","📁","📂","🗂️","📅","📆","🗒️","🗓️","📇","📈","📉","📊","📋","📌","📍","📎","🖇️","📏","📐","✂️","🗃️","🗄️","🗑️","🔒","🔓","🔏","🔐","🔑","🗝️","🔨","🪓","⛏️","⚒️","🛠️","🗡️","⚔️","🔫","🪃","🏹","🛡️","🪚","🔧","🪛","🔩","⚙️","🗜️","⚖️","🦯","🔗","⛓️","🪝","🧰","🧲","🪜","⚗️","🧪","🧫","🧬","🔬","🔭","📡","💉","🩸","💊","🩹","🩼","🩺","🩻","🚪","🛗","🪞","🪟","🛏️","🛋️","🪑","🚽","🪠","🚿","🛁","🪤","🪒","🧴","🧷","🧹","🧺","🧻","🪣","🧼","🫧","🪥","🧽","🧯","🛒","🚬","⚰️","🪦","⚱️","🗿","🪧","🪪"],"🏧":["🏧","🚮","🚰","♿","🚹","🚺","🚻","🚼","🚾","🛂","🛃","🛄","🛅","⚠️","🚸","⛔","🚫","🚳","🚭","🚯","🚱","🚷","📵","🔞","☢️","☣️","⬆️","↗️","➡️","↘️","⬇️","↙️","⬅️","↖️","↕️","↔️","↩️","↪️","⤴️","⤵️","🔃","🔄","🔙","🔚","🔛","🔜","🔝","🛐","⚛️","🕉️","✡️","☸️","☯️","✝️","☦️","☪️","☮️","🕎","🔯","♈","♉","♊","♋","♌","♍","♎","♏","♐","♑","♒","♓","⛎","🔀","🔁","🔂","▶️","⏩","⏭️","⏯️","◀️","⏪","⏮️","🔼","⏫","🔽","⏬","⏸️","⏹️","⏺️","⏏️","🎦","🔅","🔆","📶","📳","📴","♀️","♂️","⚧️","✖️","➕","➖","➗","🟰","♾️","‼️","⁉️","❓","❔","❕","❗","〰️","💱","💲","⚕️","♻️","⚜️","🔱","📛","🔰","⭕","✅","☑️","✔️","❌","❎","➰","➿","〽️","✳️","✴️","❇️","©️","®️","™️","#️⃣","*️⃣","0️⃣","1️⃣","2️⃣","3️⃣","4️⃣","5️⃣","6️⃣","7️⃣","8️⃣","9️⃣","🔟","🔠","🔡","🔢","🔣","🔤","🅰️","🆎","🅱️","🆑","🆒","🆓","ℹ️","🆔","Ⓜ️","🆕","🆖","🅾️","🆗","🅿️","🆘","🆙","🆚","🈁","🈂️","🈷️","🈶","🈯","🉐","🈹","🈚","🈲","🉑","🈸","🈴","🈳","㊗️","㊙️","🈺","🈵","🔴","🟠","🟡"
,"🟢","🔵","🟣","🟤","⚫","⚪","🟥","🟧","🟨","🟩","🟦","🟪","🟫","⬛","⬜","◼️","◻️","◾","◽","▪️","▫️","🔶","🔷","🔸","🔹","🔺","🔻","💠","🔘","🔳","🔲"],"🏁":["🏁","🚩","🎌","🏴","🏳️","🏳️‍🌈","🏳️‍⚧️","🏴‍☠️","🇦🇨","🇦🇩","🇦🇪","🇦🇫","🇦🇬","🇦🇮","🇦🇱","🇦🇲","🇦🇴","🇦🇶","🇦🇷","🇦🇸","🇦🇹","🇦🇺","🇦🇼","🇦🇽","🇦🇿","🇧🇦","🇧🇧","🇧🇩","🇧🇪","🇧🇫","🇧🇬","🇧🇭","🇧🇮","🇧🇯","🇧🇱","🇧🇲","🇧🇳","🇧🇴","🇧🇶","🇧🇷","🇧🇸","🇧🇹","🇧🇻","🇧🇼","🇧🇾","🇧🇿","🇨🇦","🇨🇨","🇨🇩","🇨🇫","🇨🇬","🇨🇭","🇨🇮","🇨🇰","🇨🇱","🇨🇲","🇨🇳","🇨🇴","🇨🇵","🇨🇷","🇨🇺","🇨🇻","🇨🇼","🇨🇽","🇨🇾","🇨🇿","🇩🇪","🇩🇬","🇩🇯","🇩🇰","🇩🇲","🇩🇴","🇩🇿","🇪🇦","🇪🇨","🇪🇪","🇪🇬","🇪🇭","🇪🇷","🇪🇸","🇪🇹","🇪🇺","🇫🇮","🇫🇯","🇫🇰","🇫🇲","🇫🇴","🇫🇷","🇬🇦","🇬🇧","🇬🇩","🇬🇪","🇬🇫","🇬🇬","🇬🇭","🇬🇮","🇬🇱","🇬🇲","🇬🇳","🇬🇵","🇬🇶","🇬🇷","🇬🇸","🇬🇹","🇬🇺","🇬🇼","🇬🇾","🇭🇰","🇭🇲","🇭🇳","🇭🇷","🇭🇹","🇭🇺","🇮🇨","🇮🇩","🇮🇪","🇮🇱","🇮🇲","🇮🇳","🇮🇴","🇮🇶","🇮🇷","🇮🇸","🇮🇹","🇯🇪","🇯🇲","🇯🇴","🇯🇵","🇰🇪","🇰🇬","🇰🇭","🇰🇮","🇰🇲","🇰🇳","🇰🇵","🇰🇷","🇰🇼","🇰🇾","🇰🇿","🇱🇦","🇱🇧","🇱🇨","🇱🇮","🇱🇰","🇱🇷","🇱🇸","🇱🇹","🇱🇺","🇱🇻","🇱🇾","🇲🇦","🇲🇨","🇲🇩","🇲🇪","🇲🇫","🇲🇬","🇲🇭","🇲🇰","🇲🇱","🇲🇲","🇲🇳","🇲🇴","🇲🇵","🇲🇶","🇲🇷","🇲🇸","🇲🇹","🇲🇺","🇲🇻","🇲🇼","🇲🇽","🇲🇾","🇲🇿","🇳🇦","🇳🇨","🇳🇪","🇳🇫","🇳🇬","🇳🇮","🇳🇱","🇳🇴","🇳🇵","🇳🇷","🇳🇺","🇳🇿","🇴🇲","🇵🇦","🇵🇪","🇵🇫","🇵🇬","🇵🇭","🇵🇰","🇵🇱","🇵🇲","🇵🇳","🇵🇷","🇵🇸","🇵🇹","🇵🇼","🇵🇾","🇶🇦","🇷🇪","🇷🇴","🇷🇸","🇷🇺","🇷🇼","🇸🇦","🇸🇧","🇸🇨","🇸🇩","🇸🇪","🇸🇬","🇸🇭","🇸🇮","🇸🇯","🇸🇰","🇸🇱","🇸🇲","🇸🇳","🇸🇴","🇸🇷","🇸🇸","🇸🇹","🇸🇻","🇸🇽","🇸🇾","🇸🇿","🇹🇦","🇹🇨","🇹🇩","🇹🇫","🇹🇬","🇹🇭","🇹🇯","🇹🇰","🇹🇱","🇹🇲","🇹🇳","🇹🇴","🇹🇷","🇹🇹","🇹🇻","🇹🇼","🇹🇿","🇺🇦","🇺🇬","🇺🇲","🇺🇳","🇺🇸","🇺🇾","🇺🇿","🇻🇦","🇻🇨","🇻🇪","🇻🇬","🇻🇮","🇻🇳","🇻🇺","🇼🇫","🇼🇸","🇽🇰","🇾🇪","🇾🇹","🇿🇦","🇿🇲","🇿🇼","🏴󠁧󠁢󠁥󠁮󠁧󠁿","🏴󠁧󠁢󠁳󠁣󠁴󠁿","🏴󠁧󠁢󠁷󠁬󠁳󠁿"]}`, + "map_provider": "openstreetmap", + "map_google_tile_type": "regular", + "mime_mapping": `{".xlsx":"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",".xltx":"application/vnd.openxmlformats-officedocument.spreadsheetml.template",".potx":"application/vnd.openxmlformats-officedocument.presentationml.template",".ppsx":"application/vnd.openxmlformats-officedocument.presentationml.slideshow",".pptx":"application/vnd.openxmlformats-officedocument.presentationml.presentation",".sldx":"application/vnd.openxmlformats-officedocument.presentationml.slide",".docx":"application/vnd.openxmlformats-officedocument.wordprocessingml.document",".dotx":"application/vnd.openxmlformats-officedocument.wordprocessingml.template",".xlam":"application/vnd.ms-excel.addin.macroEnabled.12",".xlsb":"application/vnd.ms-excel.sheet.binary.macroEnabled.12",".apk":"application/vnd.android.package-archive",".hqx":"application/mac-binhex40",".cpt":"application/mac-compactpro",".doc":"application/msword",".ogg":"application/ogg",".pdf":"application/pdf",".rtf":"text/rtf",".mif":"application/vnd.mif",".xls":"application/vnd.ms-excel",".ppt":"application/vnd.ms-powerpoint",".odc":"application/vnd.oasis.opendocument.chart",".odb":"application/vnd.oasis.opendocument.database",".odf":"application/vnd.oasis.opendocument.formula",".odg":"application/vnd.oasis.opendocument.graphics",".otg":"application/vnd.oasis.opendocument.graphics-template",".odi":"application/vnd.oasis.opendocument.image",".odp":"application/vnd.oasis.opendocument.presentation",".otp":"application/vnd.oasis.opendocument.presentation-template",".ods":"application/vnd.oasis.opendocument.spreadsheet",".ots":"application/vnd.oasis.opendocument.spreadsheet-template",".odt":"application/vnd.oasis.opendocument.text",".odm":"application/vnd.oasis.opendocument.text-master",".ott":"application/vnd.oasis.opendocument.text-template",".oth":"application/vnd.oasis.opendocument.text-web",".sxw":"application/vnd.sun.xml.writer",".stw":"application/vnd.sun.xml.writer.template",".sxc":"appli
cation/vnd.sun.xml.calc",".stc":"application/vnd.sun.xml.calc.template",".sxd":"application/vnd.sun.xml.draw",".std":"application/vnd.sun.xml.draw.template",".sxi":"application/vnd.sun.xml.impress",".sti":"application/vnd.sun.xml.impress.template",".sxg":"application/vnd.sun.xml.writer.global",".sxm":"application/vnd.sun.xml.math",".sis":"application/vnd.symbian.install",".wbxml":"application/vnd.wap.wbxml",".wmlc":"application/vnd.wap.wmlc",".wmlsc":"application/vnd.wap.wmlscriptc",".bcpio":"application/x-bcpio",".torrent":"application/x-bittorrent",".bz2":"application/x-bzip2",".vcd":"application/x-cdlink",".pgn":"application/x-chess-pgn",".cpio":"application/x-cpio",".csh":"application/x-csh",".dvi":"application/x-dvi",".spl":"application/x-futuresplash",".gtar":"application/x-gtar",".hdf":"application/x-hdf",".jar":"application/x-java-archive",".jnlp":"application/x-java-jnlp-file",".js":"application/x-javascript",".ksp":"application/x-kspread",".chrt":"application/x-kchart",".kil":"application/x-killustrator",".latex":"application/x-latex",".rpm":"application/x-rpm",".sh":"application/x-sh",".shar":"application/x-shar",".swf":"application/x-shockwave-flash",".sit":"application/x-stuffit",".sv4cpio":"application/x-sv4cpio",".sv4crc":"application/x-sv4crc",".tar":"application/x-tar",".tcl":"application/x-tcl",".tex":"application/x-tex",".man":"application/x-troff-man",".me":"application/x-troff-me",".ms":"application/x-troff-ms",".ustar":"application/x-ustar",".src":"application/x-wais-source",".zip":"application/zip",".m3u":"audio/x-mpegurl",".ra":"audio/x-pn-realaudio",".wav":"audio/x-wav",".wma":"audio/x-ms-wma",".wax":"audio/x-ms-wax",".pdb":"chemical/x-pdb",".xyz":"chemical/x-xyz",".bmp":"image/bmp",".gif":"image/gif",".ief":"image/ief",".png":"image/png",".wbmp":"image/vnd.wap.wbmp",".ras":"image/x-cmu-raster",".pnm":"image/x-portable-anymap",".pbm":"image/x-portable-bitmap",".pgm":"image/x-portable-graymap",".ppm":"image/x-portable-pixmap",".rgb":"image/x-rgb",".xbm":"image/x-xbitmap",".xpm":"image/x-xpixmap",".xwd":"image/x-xwindowdump",".css":"text/css",".rtx":"text/richtext",".tsv":"text/tab-separated-values",".jad":"text/vnd.sun.j2me.app-descriptor",".wml":"text/vnd.wap.wml",".wmls":"text/vnd.wap.wmlscript",".etx":"text/x-setext",".mxu":"video/vnd.mpegurl",".flv":"video/x-flv",".wm":"video/x-ms-wm",".wmv":"video/x-ms-wmv",".wmx":"video/x-ms-wmx",".wvx":"video/x-ms-wvx",".avi":"video/x-msvideo",".movie":"video/x-sgi-movie",".ice":"x-conference/x-cooltalk",".3gp":"video/3gpp",".ai":"application/postscript",".aif":"audio/x-aiff",".aifc":"audio/x-aiff",".aiff":"audio/x-aiff",".asc":"text/plain",".atom":"application/atom+xml",".au":"audio/basic",".bin":"application/octet-stream",".cdf":"application/x-netcdf",".cgm":"image/cgm",".class":"application/octet-stream",".dcr":"application/x-director",".dif":"video/x-dv",".dir":"application/x-director",".djv":"image/vnd.djvu",".djvu":"image/vnd.djvu",".dll":"application/octet-stream",".dmg":"application/octet-stream",".dms":"application/octet-stream",".dtd":"application/xml-dtd",".dv":"video/x-dv",".dxr":"application/x-director",".eps":"application/postscript",".exe":"application/octet-stream",".ez":"application/andrew-inset",".gram":"application/srgs",".grxml":"application/srgs+xml",".gz":"application/x-gzip",".htm":"text/html",".html":"text/html",".ico":"image/x-icon",".ics":"text/calendar",".ifb":"text/calendar",".iges":"model/iges",".igs":"model/iges",".jp2":"image/jp2",".jpe":"image/jpeg",".jpeg":"image/jpeg",".jpg":"image/jpeg",".kar
":"audio/midi",".lha":"application/octet-stream",".lzh":"application/octet-stream",".m4a":"audio/mp4a-latm",".m4p":"audio/mp4a-latm",".m4u":"video/vnd.mpegurl",".m4v":"video/x-m4v",".mac":"image/x-macpaint",".mathml":"application/mathml+xml",".mesh":"model/mesh",".mid":"audio/midi",".midi":"audio/midi",".mov":"video/quicktime",".mp2":"audio/mpeg",".mp3":"audio/mpeg",".mp4":"video/mp4",".mpe":"video/mpeg",".mpeg":"video/mpeg",".mpg":"video/mpeg",".mpga":"audio/mpeg",".msh":"model/mesh",".nc":"application/x-netcdf",".oda":"application/oda",".ogv":"video/ogv",".pct":"image/pict",".pic":"image/pict",".pict":"image/pict",".pnt":"image/x-macpaint",".pntg":"image/x-macpaint",".ps":"application/postscript",".qt":"video/quicktime",".qti":"image/x-quicktime",".qtif":"image/x-quicktime",".ram":"audio/x-pn-realaudio",".rdf":"application/rdf+xml",".rm":"application/vnd.rn-realmedia",".roff":"application/x-troff",".sgm":"text/sgml",".sgml":"text/sgml",".silo":"model/mesh",".skd":"application/x-koan",".skm":"application/x-koan",".skp":"application/x-koan",".skt":"application/x-koan",".smi":"application/smil",".smil":"application/smil",".snd":"audio/basic",".so":"application/octet-stream",".svg":"image/svg+xml",".t":"application/x-troff",".texi":"application/x-texinfo",".texinfo":"application/x-texinfo",".tif":"image/tiff",".tiff":"image/tiff",".tr":"application/x-troff",".txt":"text/plain",".vrml":"model/vrml",".vxml":"application/voicexml+xml",".webm":"video/webm",".wrl":"model/vrml",".xht":"application/xhtml+xml",".xhtml":"application/xhtml+xml",".xml":"application/xml",".xsl":"application/xml",".xslt":"application/xslt+xml",".xul":"application/vnd.mozilla.xul+xml",".webp":"image/webp",".323":"text/h323",".aab":"application/x-authoware-bin",".aam":"application/x-authoware-map",".aas":"application/x-authoware-seg",".acx":"application/internet-property-stream",".als":"audio/X-Alpha5",".amc":"application/x-mpeg",".ani":"application/octet-stream",".asd":"application/astound",".asf":"video/x-ms-asf",".asn":"application/astound",".asp":"application/x-asap",".asr":"video/x-ms-asf",".asx":"video/x-ms-asf",".avb":"application/octet-stream",".awb":"audio/amr-wb",".axs":"application/olescript",".bas":"text/plain",".bin 
":"application/octet-stream",".bld":"application/bld",".bld2":"application/bld2",".bpk":"application/octet-stream",".c":"text/plain",".cal":"image/x-cals",".cat":"application/vnd.ms-pkiseccat",".ccn":"application/x-cnc",".cco":"application/x-cocoa",".cer":"application/x-x509-ca-cert",".cgi":"magnus-internal/cgi",".chat":"application/x-chat",".clp":"application/x-msclip",".cmx":"image/x-cmx",".co":"application/x-cult3d-object",".cod":"image/cis-cod",".conf":"text/plain",".cpp":"text/plain",".crd":"application/x-mscardfile",".crl":"application/pkix-crl",".crt":"application/x-x509-ca-cert",".csm":"chemical/x-csml",".csml":"chemical/x-csml",".cur":"application/octet-stream",".dcm":"x-lml/x-evm",".dcx":"image/x-dcx",".der":"application/x-x509-ca-cert",".dhtml":"text/html",".dot":"application/msword",".dwf":"drawing/x-dwf",".dwg":"application/x-autocad",".dxf":"application/x-autocad",".ebk":"application/x-expandedbook",".emb":"chemical/x-embl-dl-nucleotide",".embl":"chemical/x-embl-dl-nucleotide",".epub":"application/epub+zip",".eri":"image/x-eri",".es":"audio/echospeech",".esl":"audio/echospeech",".etc":"application/x-earthtime",".evm":"x-lml/x-evm",".evy":"application/envoy",".fh4":"image/x-freehand",".fh5":"image/x-freehand",".fhc":"image/x-freehand",".fif":"application/fractals",".flr":"x-world/x-vrml",".fm":"application/x-maker",".fpx":"image/x-fpx",".fvi":"video/isivideo",".gau":"chemical/x-gaussian-input",".gca":"application/x-gca-compressed",".gdb":"x-lml/x-gdb",".gps":"application/x-gps",".h":"text/plain",".hdm":"text/x-hdml",".hdml":"text/x-hdml",".hlp":"application/winhlp",".hta":"application/hta",".htc":"text/x-component",".hts":"text/html",".htt":"text/webviewhtml",".ifm":"image/gif",".ifs":"image/ifs",".iii":"application/x-iphone",".imy":"audio/melody",".ins":"application/x-internet-signup",".ips":"application/x-ipscript",".ipx":"application/x-ipix",".isp":"application/x-internet-signup",".it":"audio/x-mod",".itz":"audio/x-mod",".ivr":"i-world/i-vrml",".j2k":"image/j2k",".jam":"application/x-jam",".java":"text/plain",".jfif":"image/pipeg",".jpz":"image/jpeg",".jwc":"application/jwc",".kjx":"application/x-kjx",".lak":"x-lml/x-lak",".lcc":"application/fastman",".lcl":"application/x-digitalloca",".lcr":"application/x-digitalloca",".lgh":"application/lgh",".lml":"x-lml/x-lml",".lmlpack":"x-lml/x-lmlpack",".log":"text/plain",".lsf":"video/x-la-asf",".lsx":"video/x-la-asf",".m13":"application/x-msmediaview",".m14":"application/x-msmediaview",".m15":"audio/x-mod",".m3url":"audio/x-mpegurl",".m4b":"audio/mp4a-latm",".ma1":"audio/ma1",".ma2":"audio/ma2",".ma3":"audio/ma3",".ma5":"audio/ma5",".map":"magnus-internal/imagemap",".mbd":"application/mbedlet",".mct":"application/x-mascot",".mdb":"application/x-msaccess",".mdz":"audio/x-mod",".mel":"text/x-vmel",".mht":"message/rfc822",".mhtml":"message/rfc822",".mi":"application/x-mif",".mil":"image/x-cals",".mio":"audio/x-mio",".mmf":"application/x-skt-lbs",".mng":"video/x-mng",".mny":"application/x-msmoney",".moc":"application/x-mocha",".mocha":"application/x-mocha",".mod":"audio/x-mod",".mof":"application/x-yumekara",".mol":"chemical/x-mdl-molfile",".mop":"chemical/x-mopac-input",".mpa":"video/mpeg",".mpc":"application/vnd.mpohun.certificate",".mpg4":"video/mp4",".mpn":"application/vnd.mophun.application",".mpp":"application/vnd.ms-project",".mps":"application/x-mapserver",".mpv2":"video/mpeg",".mrl":"text/x-mrml",".mrm":"application/x-mrm",".msg":"application/vnd.ms-outlook",".mts":"application/metastream",".mtx":"application/metastream",".mtz"
:"application/metastream",".mvb":"application/x-msmediaview",".mzv":"application/metastream",".nar":"application/zip",".nbmp":"image/nbmp",".ndb":"x-lml/x-ndb",".ndwn":"application/ndwn",".nif":"application/x-nif",".nmz":"application/x-scream",".nokia-op-logo":"image/vnd.nok-oplogo-color",".npx":"application/x-netfpx",".nsnd":"audio/nsnd",".nva":"application/x-neva1",".nws":"message/rfc822",".oom":"application/x-AtlasMate-Plugin",".p10":"application/pkcs10",".p12":"application/x-pkcs12",".p7b":"application/x-pkcs7-certificates",".p7c":"application/x-pkcs7-mime",".p7m":"application/x-pkcs7-mime",".p7r":"application/x-pkcs7-certreqresp",".p7s":"application/x-pkcs7-signature",".pac":"audio/x-pac",".pae":"audio/x-epac",".pan":"application/x-pan",".pcx":"image/x-pcx",".pda":"image/x-pda",".pfr":"application/font-tdpfr",".pfx":"application/x-pkcs12",".pko":"application/ynd.ms-pkipko",".pm":"application/x-perl",".pma":"application/x-perfmon",".pmc":"application/x-perfmon",".pmd":"application/x-pmd",".pml":"application/x-perfmon",".pmr":"application/x-perfmon",".pmw":"application/x-perfmon",".pnz":"image/png",".pot,":"application/vnd.ms-powerpoint",".pps":"application/vnd.ms-powerpoint",".pqf":"application/x-cprplayer",".pqi":"application/cprplayer",".prc":"application/x-prc",".prf":"application/pics-rules",".prop":"text/plain",".proxy":"application/x-ns-proxy-autoconfig",".ptlk":"application/listenup",".pub":"application/x-mspublisher",".pvx":"video/x-pv-pvx",".qcp":"audio/vnd.qcelp",".r3t":"text/vnd.rn-realtext3d",".rar":"application/octet-stream",".rc":"text/plain",".rf":"image/vnd.rn-realflash",".rlf":"application/x-richlink",".rmf":"audio/x-rmf",".rmi":"audio/mid",".rmm":"audio/x-pn-realaudio",".rmvb":"audio/x-pn-realaudio",".rnx":"application/vnd.rn-realplayer",".rp":"image/vnd.rn-realpix",".rt":"text/vnd.rn-realtext",".rte":"x-lml/x-gps",".rtg":"application/metastream",".rv":"video/vnd.rn-realvideo",".rwc":"application/x-rogerwilco",".s3m":"audio/x-mod",".s3z":"audio/x-mod",".sca":"application/x-supercard",".scd":"application/x-msschedule",".sct":"text/scriptlet",".sdf":"application/e-score",".sea":"application/x-stuffit",".setpay":"application/set-payment_old-initiation",".setreg":"application/set-registration-initiation",".shtml":"text/html",".shtm":"text/html",".shw":"application/presentations",".si6":"image/si6",".si7":"image/vnd.stiwap.sis",".si9":"image/vnd.lgtwap.sis",".slc":"application/x-salsa",".smd":"audio/x-smd",".smp":"application/studiom",".smz":"audio/x-smd",".spc":"application/x-pkcs7-certificates",".spr":"application/x-sprite",".sprite":"application/x-sprite",".sdp":"application/sdp",".spt":"application/x-spt",".sst":"application/vnd.ms-pkicertstore",".stk":"application/hyperstudio",".stl":"application/vnd.ms-pkistl",".stm":"text/html",".svf":"image/vnd",".svh":"image/svh",".svr":"x-world/x-svr",".swfl":"application/x-shockwave-flash",".tad":"application/octet-stream",".talk":"text/x-speech",".taz":"application/x-tar",".tbp":"application/x-timbuktu",".tbt":"application/x-timbuktu",".tgz":"application/x-compressed",".thm":"application/vnd.eri.thm",".tki":"application/x-tkined",".tkined":"application/x-tkined",".toc":"application/toc",".toy":"image/toy",".trk":"x-lml/x-gps",".trm":"application/x-msterminal",".tsi":"audio/tsplayer",".tsp":"application/dsptype",".ttf":"application/octet-stream",".ttz":"application/t-time",".uls":"text/iuls",".ult":"audio/x-mod",".uu":"application/x-uuencode",".uue":"application/x-uuencode",".vcf":"text/x-vcard",".vdo":"video/vdo",".vib":"audio/v
ib",".viv":"video/vivo",".vivo":"video/vivo",".vmd":"application/vocaltec-media-desc",".vmf":"application/vocaltec-media-file",".vmi":"application/x-dreamcast-vms-info",".vms":"application/x-dreamcast-vms",".vox":"audio/voxware",".vqe":"audio/x-twinvq-plugin",".vqf":"audio/x-twinvq",".vql":"audio/x-twinvq",".vre":"x-world/x-vream",".vrt":"x-world/x-vrt",".vrw":"x-world/x-vream",".vts":"workbook/formulaone",".wcm":"application/vnd.ms-works",".wdb":"application/vnd.ms-works",".web":"application/vnd.xara",".wi":"image/wavelet",".wis":"application/x-InstallShield",".wks":"application/vnd.ms-works",".wmd":"application/x-ms-wmd",".wmf":"application/x-msmetafile",".wmlscript":"text/vnd.wap.wmlscript",".wmz":"application/x-ms-wmz",".wpng":"image/x-up-wpng",".wps":"application/vnd.ms-works",".wpt":"x-lml/x-gps",".wri":"application/x-mswrite",".wrz":"x-world/x-vrml",".ws":"text/vnd.wap.wmlscript",".wsc":"application/vnd.wap.wmlscriptc",".wv":"video/wavelet",".wxl":"application/x-wxl",".x-gzip":"application/x-gzip",".xaf":"x-world/x-vrml",".xar":"application/vnd.xara",".xdm":"application/x-xdma",".xdma":"application/x-xdma",".xdw":"application/vnd.fujixerox.docuworks",".xhtm":"application/xhtml+xml",".xla":"application/vnd.ms-excel",".xlc":"application/vnd.ms-excel",".xll":"application/x-excel",".xlm":"application/vnd.ms-excel",".xlt":"application/vnd.ms-excel",".xlw":"application/vnd.ms-excel",".xm":"audio/x-mod",".xmz":"audio/x-mod",".xof":"x-world/x-vrml",".xpi":"application/x-xpinstall",".xsit":"text/xml",".yz1":"application/x-yz1",".z":"application/x-compress",".zac":"application/x-zaurus-zac",".json":"application/json"}`, + "file_viewers": `[{"viewers":[{"id":"music","type":"builtin","action":"view","display_name":"fileManager.musicPlayer","exts":["mp3","ogg","wav","flac","m4a"]},{"id":"epub","type":"builtin","action":"view","display_name":"fileManager.epubViewer","exts":["epub"]},{"id":"googledocs","type":"custom","action":"view","display_name":"fileManager.googledocs","icon":"/static/img/viewers/gdrive.png","url":"https://docs.google.com/gview?url={$src}&embedded=true","exts":["jpeg","png","gif","tiff","bmp","webm","mpeg4","3gpp","mov","avi","mpegps","wmv","flv","txt","css","html","php","c","cpp","h","hpp","js","doc","docx","xls","xlsx","ppt","pptx","pdf","pages","ai","psd","tiff","dxf","svg","eps","ps","ttf","xps"],"max_size":26214400},{"id":"m365online","type":"custom","action":"view","display_name":"fileManager.m365viewer","icon":"/static/img/viewers/m365.svg","url":"https://view.officeapps.live.com/op/view.aspx?src={$src}","exts":["doc","docx","docm","dotm","dotx","xlsx","xlsb","xls","xlsm","pptx","ppsx","ppt","pps","pptm","potm","ppam","potx","ppsm"],"max_size":10485760},{"id":"pdf","type":"builtin","action":"view","display_name":"fileManager.pdfViewer","exts":["pdf"]},{"id":"video","type":"builtin","action":"view","icon":"/static/img/viewers/artplayer.png","display_name":"Artplayer","exts":["mp4","mkv","webm","avi","m3u8","mov","m3u8"]},{"id":"markdown","type":"builtin","action":"edit","display_name":"fileManager.markdownEditor","exts":["md"],"templates":[{"ext":"md","display_name":"Markdown"}]},{"id":"drawio","type":"builtin","action":"edit","icon":"/static/img/viewers/drawio.svg","display_name":"draw.io","exts":["drawio","dwb"],"props":{"host":"https://embed.diagrams.net"},"templates":[{"ext":"drawio","display_name":"fileManager.diagram"},{"ext":"dwb","display_name":"fileManager.whiteboard"}]},{"id":"image","type":"builtin","action":"edit","display_name":"fileManager.imageViewer","exts
":["bmp","png","gif","jpg","jpeg","svg","webp","heic","heif"]},{"id":"monaco","type":"builtin","action":"edit","icon":"/static/img/viewers/monaco.svg","display_name":"fileManager.monacoEditor","exts":["md","txt","json","php","py","bat","c","h","cpp","hpp","cs","css","dockerfile","go","html","htm","ini","java","js","jsx","less","lua","sh","sql","xml","yaml"],"templates":[{"ext":"txt","display_name":"fileManager.text"}]},{"id":"photopea","type":"builtin","icon":"/static/img/viewers/photopea.png","action":"edit","display_name":"Photopea","exts":["psd","ai","indd","xcf","xd","fig","kri","clip","pxd","pxz","cdr","ufo","afphoyo","svg","esp","pdf","pdn","wmf","emf","png","jpg","jpeg","gif","webp","ico","icns","bmp","avif","heic","jxl","ppm","pgm","pbm","tiff","dds","iff","anim","tga","dng","nef","cr2","cr3","arw","rw2","raf","orf","gpr","3fr","fff"]}]}]`, + "logto_enabled": "0", + "logto_config": `{"direct_sign_in":true,"display_name":"vas.sso"}`, + "qq_login": `0`, + "qq_login_config": `{"direct_sign_in":false}`, + "license": "", +} diff --git a/inventory/share.go b/inventory/share.go new file mode 100644 index 00000000..a43c08b7 --- /dev/null +++ b/inventory/share.go @@ -0,0 +1,401 @@ +package inventory + +import ( + "context" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/predicate" + "github.com/cloudreve/Cloudreve/v4/ent/share" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/samber/lo" +) + +type ( + // Ctx keys for eager loading options. + LoadShareFile struct{} + LoadShareUser struct{} +) + +var ( + ErrShareLinkExpired = fmt.Errorf("share link expired") + ErrOwnerInactive = fmt.Errorf("owner is inactive") + ErrSourceFileInvalid = fmt.Errorf("source file is deleted") +) + +type ( + ShareClient interface { + TxOperator + // GetByID returns the share with given id. + GetByID(ctx context.Context, id int) (*ent.Share, error) + // GetByIDUser returns the share with given id and user id. + GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) + // GetByHashID returns the share with given hash id. + GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) + // Upsert creates or update a new share record. + Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) + // Viewed increase the view count of the share. + Viewed(ctx context.Context, share *ent.Share) error + // Downloaded increase the download count of the share. + Downloaded(ctx context.Context, share *ent.Share) error + // Delete deletes the share. + Delete(ctx context.Context, shareId int) error + // List returns a list of shares with the given args. + List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) + // CountByTimeRange counts the number of shares created in the given time range. + CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) + // DeleteBatch deletes the shares with the given ids. 
+ DeleteBatch(ctx context.Context, shareIds []int) error + } + + CreateShareParams struct { + Existed *ent.Share + Password string + RemainDownloads int + Expires *time.Time + OwnerID int + FileID int + } + + ListShareArgs struct { + *PaginationArgs + UserID int + FileID int + PublicOnly bool + } + ListShareResult struct { + *PaginationResults + Shares []*ent.Share + } +) + +func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient { + return &shareClient{ + client: client, + hasher: hasher, + maxSQlParam: sqlParamLimit(dbType), + } +} + +type shareClient struct { + maxSQlParam int + client *ent.Client + hasher hashid.Encoder +} + +func (c *shareClient) SetClient(newClient *ent.Client) TxOperator { + return &shareClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam} +} + +func (c *shareClient) GetClient() *ent.Client { + return c.client +} + +func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) { + if start == nil || end == nil { + return c.client.Share.Query().Count(ctx) + } + + return c.client.Share.Query().Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end)).Count(ctx) +} + +func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) { + if params.Existed != nil { + createQuery := c.client.Share. + UpdateOne(params.Existed) + if params.RemainDownloads > 0 { + createQuery.SetRemainDownloads(params.RemainDownloads) + } else { + createQuery.ClearRemainDownloads() + } + if params.Expires != nil { + createQuery.SetNillableExpires(params.Expires) + } else { + createQuery.ClearExpires() + } + + return createQuery.Save(ctx) + } + + query := c.client.Share. + Create(). + SetUserID(params.OwnerID). + SetFileID(params.FileID) + if params.Password != "" { + query.SetPassword(params.Password) + } + if params.RemainDownloads > 0 { + query.SetRemainDownloads(params.RemainDownloads) + } + if params.Expires != nil { + query.SetNillableExpires(params.Expires) + } + + return query.Save(ctx) +} + +func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) { + id, err := c.hasher.Decode(idRaw, hashid.ShareID) + if err != nil { + return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, err) + } + + return c.GetByID(ctx, id) +} + +func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) { + s, err := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.ID(id))).First(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query share %d: %w", id, err) + } + + return s, nil +} + +func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) { + s, err := withShareEagerLoading(ctx, c.client.Share.Query(). + Where(share.ID(id))). + Where(share.HasUserWith(user.ID(uid))).First(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query share %d: %w", id, err) + } + + return s, nil +} + +func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error { + _, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx) + return err +} + +func (c *shareClient) Delete(ctx context.Context, shareId int) error { + return c.client.Share.DeleteOneID(shareId).Exec(ctx) +} + +// Viewed increments the view count of the share. +func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error { + _, err := c.client.Share.UpdateOneID(share.ID).AddViews(1).Save(ctx) + return err +} + +// Downloaded increments the download count of the share. 
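Upsert above doubles as create and renew: when Existed is set it rewrites the download limit and expiry (clearing whichever is omitted), otherwise it inserts a new share row; the Viewed and Downloaded helpers below then maintain the access counters. A minimal creation sketch, with ownerID and fileID as placeholder record IDs and time, ent, and this package assumed imported:

```go
// newShare creates a password-protected share link that expires in a week
// and allows ten downloads. Pass Existed to renew an existing share instead.
func newShare(ctx context.Context, shares inventory.ShareClient, ownerID, fileID int) (*ent.Share, error) {
	expires := time.Now().Add(7 * 24 * time.Hour)
	return shares.Upsert(ctx, &inventory.CreateShareParams{
		OwnerID:         ownerID,
		FileID:          fileID,
		Password:        "secret", // empty string leaves the share public
		RemainDownloads: 10,       // values <= 0 clear the limit on renewal
		Expires:         &expires, // nil clears the expiry on renewal
	})
}
```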
+func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
+	stm := c.client.Share.
+		UpdateOneID(share.ID).
+		AddDownloads(1)
+	if share.RemainDownloads != nil && *share.RemainDownloads >= 0 {
+		stm.AddRemainDownloads(-1)
+	}
+	_, err := stm.Save(ctx)
+	return err
+}
+
+func IsValidShare(share *ent.Share) error {
+	// Check if share is expired
+	if err := IsShareExpired(share); err != nil {
+		return err
+	}
+
+	// Check owner status
+	owner, err := share.Edges.UserOrErr()
+	if err != nil || owner.Status != user.StatusActive {
+		// Owner already deleted, or not active.
+		return ErrOwnerInactive
+	}
+
+	// Check source file status
+	file, err := share.Edges.FileOrErr()
+	if err != nil || file.FileChildren == 0 || file.OwnerID != owner.ID {
+		// Source file already deleted
+		return ErrSourceFileInvalid
+	}
+
+	return nil
+}
+
+func IsShareExpired(share *ent.Share) error {
+	// Check if share is expired
+	if (share.Expires != nil && share.Expires.Before(time.Now())) ||
+		(share.RemainDownloads != nil && *share.RemainDownloads <= 0) {
+		return ErrShareLinkExpired
+	}
+
+	return nil
+}
+
+func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
+	rawQuery := c.listQuery(args)
+	query := withShareEagerLoading(ctx, rawQuery)
+
+	var (
+		shares        []*ent.Share
+		err           error
+		paginationRes *PaginationResults
+	)
+	if args.UseCursorPagination {
+		shares, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
+	} else {
+		shares, paginationRes, err = c.offsetPagination(ctx, query, args, 10)
+	}
+
+	if err != nil {
+		return nil, fmt.Errorf("query failed with pagination: %w", err)
+	}
+
+	return &ListShareResult{
+		Shares:            shares,
+		PaginationResults: paginationRes,
+	}, nil
+}
+
+func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
+	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
+	query.Order(getShareOrderOption(args)...)
+
+	var (
+		pageToken *PageToken
+		err       error
+	)
+	if args.PageToken != "" {
+		pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
+		if err != nil {
+			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
+		}
+	}
+	queryPaged := getShareCursorQuery(args, pageToken, query)
+
+	// Use page size + 1 to determine if there are more items to come
+	queryPaged.Limit(pageSize + 1)
+
+	logs, err := queryPaged.
+		All(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// More items to come
+	nextTokenStr := ""
+	if len(logs) > pageSize {
+		lastItem := logs[len(logs)-2]
+		nextToken, err := getShareNextPageToken(c.hasher, lastItem, args)
+		if err != nil {
+			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
+		}
+
+		nextTokenStr = nextToken
+	}
+
+	return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
+		PageSize:      pageSize,
+		NextPageToken: nextTokenStr,
+		IsCursor:      true,
+	}, nil
+}
+
+func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
+	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
+	query.Order(getShareOrderOption(args)...)
+ + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, nil, err + } + + logs, err := query.Limit(pageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, nil, err + } + + return logs, &PaginationResults{ + PageSize: pageSize, + TotalItems: total, + Page: args.Page, + }, nil +} + +func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery { + query := c.client.Share.Query() + if args.UserID > 0 { + query.Where(share.HasUserWith(user.ID(args.UserID))) + } + + if args.PublicOnly { + query.Where(share.PasswordIsNil()) + } + + if args.FileID > 0 { + query.Where(share.HasFileWith(file.ID(args.FileID))) + } + + return query +} + +// getShareNextPageToken returns the next page token for the given last share. +func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) { + token := &PageToken{ + ID: last.ID, + } + + return token.Encode(hasher, hashid.EncodeShareID) +} + +func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery { + o := &sql.OrderTermOptions{} + getOrderTerm(args.Order)(o) + + predicates, ok := shareCursorQuery[args.OrderBy] + if !ok { + predicates = shareCursorQuery[share.FieldID] + } + + if token != nil { + query.Where(predicates[o.Desc](token)) + } + + return query +} + +var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{ + share.FieldID: { + true: func(token *PageToken) predicate.Share { + return share.IDLT(token.ID) + }, + false: func(token *PageToken) predicate.Share { + return share.IDGT(token.ID) + }, + }, +} + +func getShareOrderOption(args *ListShareArgs) []share.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case share.FieldViews: + return []share.OrderOption{share.ByViews(orderTerm), share.ByID(orderTerm)} + case share.FieldDownloads: + return []share.OrderOption{share.ByDownloads(orderTerm), share.ByID(orderTerm)} + case share.FieldRemainDownloads: + return []share.OrderOption{share.ByRemainDownloads(orderTerm), share.ByID(orderTerm)} + default: + return []share.OrderOption{share.ByID(orderTerm)} + } +} + +func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery { + if v, ok := ctx.Value(LoadShareFile{}).(bool); ok && v { + q.WithFile(func(q *ent.FileQuery) { + withFileEagerLoading(ctx, q) + }) + } + if v, ok := ctx.Value(LoadShareUser{}).(bool); ok && v { + q.WithUser(func(q *ent.UserQuery) { + withUserEagerLoading(ctx, q) + }) + } + + return q +} diff --git a/inventory/task.go b/inventory/task.go new file mode 100644 index 00000000..5e64e614 --- /dev/null +++ b/inventory/task.go @@ -0,0 +1,314 @@ +package inventory + +import ( + "context" + "fmt" + + "entgo.io/ent/dialect/sql" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +type ( + // Ctx keys for eager loading options. + LoadTaskUser struct{} + + TaskArgs struct { + Status task.Status + Type string + PublicState *types.TaskPublicState + PrivateState string + OwnerID int + CorrelationID uuid.UUID + } +) + +type TaskClient interface { + TxOperator + // New creates a new task with the given args. + New(ctx context.Context, task *TaskArgs) (*ent.Task, error) + // Update updates the task with the given args. 
+ Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) + // GetPendingTasks returns all pending tasks of given type. + GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) + // GetTaskByID returns the task with the given ID. + GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) + // SetCompleteByID sets the task with the given ID to complete. + SetCompleteByID(ctx context.Context, taskID int) error + // List returns a list of tasks with the given args. + List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) + // DeleteByIDs deletes the tasks with the given IDs. + DeleteByIDs(ctx context.Context, ids ...int) error +} + +type ( + ListTaskArgs struct { + *PaginationArgs + Types []string + Status []task.Status + UserID int + CorrelationID *uuid.UUID + } + + ListTaskResult struct { + *PaginationResults + Tasks []*ent.Task + } +) + +func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient { + return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher} +} + +type taskClient struct { + maxSQlParam int + hasher hashid.Encoder + client *ent.Client +} + +func (c *taskClient) SetClient(newClient *ent.Client) TxOperator { + return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher} +} + +func (c *taskClient) GetClient() *ent.Client { + return c.client +} + +func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) { + stm := c.client.Task. + Create(). + SetType(task.Type). + SetPublicState(task.PublicState) + if task.PrivateState != "" { + stm.SetPrivateState(task.PrivateState) + } + + if task.OwnerID != 0 { + stm.SetUserID(task.OwnerID) + } + + if task.Status != "" { + stm.SetStatus(task.Status) + } + + if task.CorrelationID.String() != uuid.Nil.String() { + stm.SetCorrelationID(task.CorrelationID) + } + + newTask, err := stm.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create task: %w", err) + } + + return newTask, nil +} + +func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error { + _, err := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx) + return err +} + +func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) { + stm := c.client.Task.UpdateOne(task). + SetPublicState(args.PublicState) + + task.PublicState = args.PublicState + + if task.PrivateState != "" { + stm.SetPrivateState(task.PrivateState) + task.PrivateState = args.PrivateState + } + + if task.Status != "" { + stm.SetStatus(args.Status) + task.Status = args.Status + } + + if err := stm.Exec(ctx); err != nil { + return nil, fmt.Errorf("failed to create task: %w", err) + } + + return task, nil +} + +func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) { + tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()). + Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)). + Where(task.TypeIn(taskType...)). + All(ctx) + if err != nil { + return nil, err + } + + // Anonymous user is not loaded by default, so we need to load it manually. 
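// Tasks created without an owner keep a zero UserTasks foreign key, so the
// loop below attaches the shared in-memory anonymous user to them; downstream
// code can then treat the user edge as always populated.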
+ userClient := NewUserClient(c.client) + anonymous, err := userClient.AnonymousUser(ctx) + for _, t := range tasks { + if t.UserTasks == 0 { + if err != nil { + return nil, err + } + t.SetUser(anonymous) + } + } + + return tasks, nil +} + +func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) { + return withTaskEagerLoading(ctx, c.client.Task.Query()). + Where(task.ID(taskID)). + First(ctx) +} + +func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error { + _, err := c.client.Task.UpdateOneID(taskID). + SetStatus(task.StatusCompleted). + Save(ctx) + return err +} + +func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) { + q := c.client.Task.Query() + if args.UserID != 0 { + q.Where(task.UserTasks(args.UserID)) + } + + if args.Types != nil { + q.Where(task.TypeIn(args.Types...)) + } + + if args.Status != nil { + q.Where(task.StatusIn(args.Status...)) + } + + if args.CorrelationID != nil { + q.Where(task.CorrelationID(*args.CorrelationID)) + } + + q = withTaskEagerLoading(ctx, q) + var ( + tasks []*ent.Task + err error + paginationRes *PaginationResults + ) + + if args.UseCursorPagination { + tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1) + } else { + tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1) + } + + if err != nil { + return nil, fmt.Errorf("query failed with paginiation: %w", err) + } + + return &ListTaskResult{ + Tasks: tasks, + PaginationResults: paginationRes, + }, nil +} + +func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) { + pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin) + query.Order(task.ByID(sql.OrderDesc())) + + var ( + pageToken *PageToken + err error + queryPaged = query + ) + if args.PageToken != "" { + pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID) + if err != nil { + return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err) + } + + queryPaged = query.Where(task.IDLT(pageToken.ID)) + } + + // Use page size + 1 to determine if there are more items to come + queryPaged.Limit(pageSize + 1) + + tasks, err := queryPaged. + All(ctx) + if err != nil { + return nil, nil, err + } + + // More items to come + nextTokenStr := "" + if len(tasks) > pageSize { + lastItem := tasks[len(tasks)-2] + nextToken, err := getTaskNextPageToken(c.hasher, lastItem) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate next page token: %w", err) + } + + nextTokenStr = nextToken + } + + return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{ + PageSize: pageSize, + NextPageToken: nextTokenStr, + IsCursor: true, + }, nil +} + +func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) { + pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin) + query.Order(getTaskOrderOption(args)...) + + // Count total items + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, nil, err + } + + logs, err := query. + Limit(pageSize). + Offset(args.Page * args.PageSize). 
+ All(ctx) + if err != nil { + return nil, nil, err + } + + return logs, &PaginationResults{ + PageSize: pageSize, + TotalItems: total, + Page: args.Page, + }, nil + +} + +func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + default: + return []task.OrderOption{task.ByID(orderTerm)} + } +} + +// getTaskNextPageToken returns the next page token for the given last task. +func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) { + token := &PageToken{ + ID: last.ID, + } + + return token.Encode(hasher, hashid.EncodeTaskID) +} + +func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery { + if v, ok := ctx.Value(LoadTaskUser{}).(bool); ok && v { + q.WithUser(func(q *ent.UserQuery) { + withUserEagerLoading(ctx, q) + }) + } + + return q +} diff --git a/inventory/tx.go b/inventory/tx.go new file mode 100644 index 00000000..e267047a --- /dev/null +++ b/inventory/tx.go @@ -0,0 +1,101 @@ +package inventory + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" +) + +type TxOperator interface { + SetClient(newClient *ent.Client) TxOperator + GetClient() *ent.Client +} + +type ( + Tx struct { + tx *ent.Tx + parent *Tx + inherited bool + finished bool + storageDiff StorageDiff + } + + // TxCtx is the context key for inherited transaction + TxCtx struct{} +) + +// AppendStorageDiff appends the given storage diff to the transaction. +func (t *Tx) AppendStorageDiff(diff StorageDiff) { + root := t + for root.inherited { + root = root.parent + } + + if root.storageDiff == nil { + root.storageDiff = diff + } else { + root.storageDiff.Merge(diff) + } +} + +// WithTx wraps the given inventory client with a transaction. +func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) { + var txClient *ent.Client + var txWrapper *Tx + + if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished { + txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited} + } else { + tx, err := c.GetClient().Tx(ctx) + if err != nil { + return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err) + } + + txWrapper = &Tx{inherited: false, tx: tx} + ctx = context.WithValue(ctx, TxCtx{}, txWrapper) + } + + txClient = txWrapper.tx.Client() + return c.SetClient(txClient).(T), txWrapper, ctx, nil +} + +func Rollback(tx *Tx) error { + if !tx.inherited { + tx.finished = true + return tx.tx.Rollback() + } + + return nil +} + +func commit(tx *Tx) (bool, error) { + if !tx.inherited { + tx.finished = true + return true, tx.tx.Commit() + } + return false, nil +} + +func Commit(tx *Tx) error { + _, err := commit(tx) + return err +} + +// CommitWithStorageDiff commits the transaction and applies the storage diff, only if the transaction is not inherited. 
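// A typical call pattern looks roughly like this (a hedged sketch; fileClient,
// doWork, logger and userClient are illustrative names, not part of this diff):
//
//	fc, tx, ctx, err := inventory.WithTx(ctx, fileClient)
//	if err != nil {
//		return err
//	}
//	if err := doWork(ctx, fc); err != nil { // mutations may call tx.AppendStorageDiff
//		_ = inventory.Rollback(tx)
//		return err
//	}
//	return inventory.CommitWithStorageDiff(ctx, tx, logger, userClient)
//
// When the context already carries a live transaction (TxCtx), WithTx reuses it
// and Commit/Rollback on the inner wrapper are no-ops; only the outermost owner
// commits and applies the accumulated StorageDiff to the affected users'
// storage counters.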
+func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error { + commited, err := commit(tx) + if err != nil { + return err + } + + if !commited { + return nil + } + + if err := uc.ApplyStorageDiff(ctx, tx.storageDiff); err != nil { + l.Error("Failed to apply storage diff", "error", err) + } + + return nil +} diff --git a/inventory/types/types.go b/inventory/types/types.go new file mode 100644 index 00000000..9d77d581 --- /dev/null +++ b/inventory/types/types.go @@ -0,0 +1,224 @@ +package types + +import ( + "time" +) + +// UserSetting 用户其他配置 +type ( + UserSetting struct { + ProfileOff bool `json:"profile_off,omitempty"` + PreferredTheme string `json:"preferred_theme,omitempty"` + VersionRetention bool `json:"version_retention,omitempty"` + VersionRetentionExt []string `json:"version_retention_ext,omitempty"` + VersionRetentionMax int `json:"version_retention_max,omitempty"` + Pined []PinedFile `json:"pined,omitempty"` + Language string `json:"email_language,omitempty"` + } + + PinedFile struct { + Uri string `json:"uri"` + Name string `json:"name,omitempty"` + } + + // GroupSetting 用户组其他配置 + GroupSetting struct { + CompressSize int64 `json:"compress_size,omitempty"` // 可压缩大小 + DecompressSize int64 `json:"decompress_size,omitempty"` + RemoteDownloadOptions map[string]interface{} `json:"remote_download_options,omitempty"` // 离线下载用户组配置 + SourceBatchSize int `json:"source_batch,omitempty"` + Aria2BatchSize int `json:"aria2_batch,omitempty"` + MaxWalkedFiles int `json:"max_walked_files,omitempty"` + TrashRetention int `json:"trash_retention,omitempty"` + RedirectedSource bool `json:"redirected_source,omitempty"` + } + + // PolicySetting 非公有的存储策略属性 + PolicySetting struct { + // Upyun访问Token + Token string `json:"token"` + // 允许的文件扩展名 + FileType []string `json:"file_type"` + // OauthRedirect Oauth 重定向地址 + OauthRedirect string `json:"od_redirect,omitempty"` + // CustomProxy whether to use custom-proxy to get file content + CustomProxy bool `json:"custom_proxy,omitempty"` + // ProxyServer 反代地址 + ProxyServer string `json:"proxy_server,omitempty"` + // InternalProxy whether to use Cloudreve internal proxy to get file content + InternalProxy bool `json:"internal_proxy,omitempty"` + // OdDriver OneDrive 驱动器定位符 + OdDriver string `json:"od_driver,omitempty"` + // Region 区域代码 + Region string `json:"region,omitempty"` + // ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段 + ServerSideEndpoint string `json:"server_side_endpoint,omitempty"` + // 分片上传的分片大小 + ChunkSize int64 `json:"chunk_size,omitempty"` + // 每秒对存储端的 API 请求上限 + TPSLimit float64 `json:"tps_limit,omitempty"` + // 每秒 API 请求爆发上限 + TPSLimitBurst int `json:"tps_limit_burst,omitempty"` + // Set this to `true` to force the request to use path-style addressing, + // i.e., `http://s3.amazonaws.com/BUCKET/KEY ` + S3ForcePathStyle bool `json:"s3_path_style"` + // File extensions that support thumbnail generation using native policy API. + ThumbExts []string `json:"thumb_exts,omitempty"` + // Whether to support all file extensions for thumbnail generation. + ThumbSupportAllExts bool `json:"thumb_support_all_exts,omitempty"` + // ThumbMaxSize indicates the maximum allowed size of a thumbnail. 0 indicates that no limit is set. + ThumbMaxSize int64 `json:"thumb_max_size,omitempty"` + // Whether to upload file through server's relay. + Relay bool `json:"relay,omitempty"` + // Whether to pre allocate space for file before upload in physical disk. 
+ PreAllocate bool `json:"pre_allocate,omitempty"` + // MediaMetaExts file extensions that support media meta generation using native policy API. + MediaMetaExts []string `json:"media_meta_exts,omitempty"` + // MediaMetaGeneratorProxy whether to use local proxy to generate media meta. + MediaMetaGeneratorProxy bool `json:"media_meta_generator_proxy,omitempty"` + // ThumbGeneratorProxy whether to use local proxy to generate thumbnail. + ThumbGeneratorProxy bool `json:"thumb_generator_proxy,omitempty"` + // NativeMediaProcessing whether to use native media processing API from storage provider. + NativeMediaProcessing bool `json:"native_media_processing"` + // S3DeleteBatchSize the number of objects to delete in each batch. + S3DeleteBatchSize int `json:"s3_delete_batch_size,omitempty"` + // StreamSaver whether to use stream saver to download file in Web. + StreamSaver bool `json:"stream_saver,omitempty"` + // UseCname whether to use CNAME for endpoint (OSS). + UseCname bool `json:"use_cname,omitempty"` + // CDN domain does not need to be signed. + SourceAuth bool `json:"source_auth,omitempty"` + } + + FileType int + EntityType int + GroupPermission int + FilePermission int + DavAccountOption int + NodeCapability int + + NodeSetting struct { + Provider DownloaderProvider `json:"provider,omitempty"` + *QBittorrentSetting `json:"qbittorrent,omitempty"` + *Aria2Setting `json:"aria2,omitempty"` + // 下载监控间隔 + Interval int `json:"interval,omitempty"` + WaitForSeeding bool `json:"wait_for_seeding,omitempty"` + } + + DownloaderProvider string + + QBittorrentSetting struct { + Server string `json:"server,omitempty"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + Options map[string]any `json:"options,omitempty"` + TempPath string `json:"temp_path,omitempty"` + } + + Aria2Setting struct { + Server string `json:"server,omitempty"` + Token string `json:"token,omitempty"` + Options map[string]any `json:"options,omitempty"` + TempPath string `json:"temp_path,omitempty"` + } + + TaskPublicState struct { + Error string `json:"error,omitempty"` + ErrorHistory []string `json:"error_history,omitempty"` + ExecutedDuration time.Duration `json:"executed_duration,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + ResumeTime int64 `json:"resume_time,omitempty"` + SlaveTaskProps *SlaveTaskProps `json:"slave_task_props,omitempty"` + } + + SlaveTaskProps struct { + NodeID int `json:"node_id,omitempty"` + MasterSiteURl string `json:"master_site_u_rl,omitempty"` + MasterSiteID string `json:"master_site_id,omitempty"` + MasterSiteVersion string `json:"master_site_version,omitempty"` + } + + EntityRecycleOption struct { + UnlinkOnly bool `json:"unlink_only,omitempty"` + } + + DavAccountProps struct { + } + + PolicyType string + + FileProps struct { + } +) + +const ( + GroupPermissionIsAdmin = GroupPermission(iota) + GroupPermissionIsAnonymous + GroupPermissionShare + GroupPermissionWebDAV + GroupPermissionArchiveDownload + GroupPermissionArchiveTask + GroupPermissionWebDAVProxy + GroupPermissionShareDownload + GroupPermission_CommunityPlaceholder1 + GroupPermissionRemoteDownload + GroupPermission_CommunityPlaceholder2 + GroupPermissionRedirectedSource // not used + GroupPermissionAdvanceDelete + GroupPermission_CommunityPlaceholder3 + GroupPermission_CommunityPlaceholder4 + GroupPermissionSetExplicitUser_placeholder + GroupPermissionIgnoreFileOwnership // not used +) + +const ( + NodeCapabilityNone NodeCapability = iota + NodeCapabilityCreateArchive + 
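// Both the GroupPermission values above and these NodeCapability values appear
// to be index positions in a boolean-set field rather than bit masks: the
// middleware checks them with calls such as
// group.Permissions.Enabled(int(types.GroupPermissionWebDAV)). The boolean-set
// type itself is defined outside this diff, so this reading is inferred from
// the call sites.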
NodeCapabilityExtractArchive + NodeCapabilityRemoteDownload + NodeCapability_CommunityPlaceholder +) + +const ( + FileTypeFile FileType = iota + FileTypeFolder +) + +const ( + EntityTypeVersion EntityType = iota + EntityTypeThumbnail + EntityTypeLivePhoto +) + +func FileTypeFromString(s string) FileType { + switch s { + case "file": + return FileTypeFile + case "folder": + return FileTypeFolder + } + return -1 +} + +const ( + DavAccountReadOnly DavAccountOption = iota + DavAccountProxy +) + +const ( + PolicyTypeLocal = "local" + PolicyTypeQiniu = "qiniu" + PolicyTypeUpyun = "upyun" + PolicyTypeOss = "oss" + PolicyTypeCos = "cos" + PolicyTypeS3 = "s3" + PolicyTypeOd = "onedrive" + PolicyTypeRemote = "remote" + PolicyTypeObs = "obs" +) + +const ( + DownloaderProviderAria2 = DownloaderProvider("aria2") + DownloaderProviderQBittorrent = DownloaderProvider("qbittorrent") +) diff --git a/inventory/user.go b/inventory/user.go new file mode 100644 index 00000000..1fe16eef --- /dev/null +++ b/inventory/user.go @@ -0,0 +1,594 @@ +package inventory + +import ( + "context" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "hash" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/davaccount" + "github.com/cloudreve/Cloudreve/v4/ent/file" + "github.com/cloudreve/Cloudreve/v4/ent/passkey" + "github.com/cloudreve/Cloudreve/v4/ent/schema" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/go-webauthn/webauthn/webauthn" +) + +type ( + // Ctx keys for eager loading options. + LoadUserGroup struct{} + LoadUserPasskey struct{} + + UserCtx struct{} + UserIDCtx struct{} +) + +var ( + ErrUserEmailExisted = errors.New("user email has been registered") + ErrInactiveUserExisted = errors.New("email already registered but not activated") + ErrorUnknownPasswordType = errors.New("unknown password type") + ErrorIncorrectPassword = errors.New("incorrect password") + ErrInsufficientPoints = errors.New("insufficient points") +) + +type ( + UserClient interface { + TxOperator + // New creates a new user. If user email registered, existed User will be returned. + Create(ctx context.Context, args *NewUserArgs) (*ent.User, error) + // GetByEmail get the user with given email, user status is ignored. + GetByEmail(ctx context.Context, email string) (*ent.User, error) + // GetByID get user by its ID, user status is ignored. + GetByID(ctx context.Context, id int) (*ent.User, error) + // GetActiveByID get user by its ID, only active user will be returned. + GetActiveByID(ctx context.Context, id int) (*ent.User, error) + // SetStatus Set user to given status + SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error) + // AnonymousUser returns the anonymous user. + AnonymousUser(ctx context.Context) (*ent.User, error) + // GetLoginUserByID returns the login user by its ID. It emits some errors and fallback to anonymous user. + GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error) + // GetLoginUserByEmail returns the login user by its WebDAV credentials. + GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error) + // SaveSettings saves user settings. 
+ SaveSettings(ctx context.Context, u *ent.User) error + // SearchActive search active users by Email or nickname. + SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error) + // ApplyStorageDiff apply storage diff to user. + ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error + // UpdateAvatar updates user avatar. + UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error) + // UpdateNickname updates user nickname. + UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error) + // UpdatePassword updates user password. + UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error) + // UpdateTwoFASecret updates user two factor secret. + UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error) + // ListPasskeys list user's passkeys. + ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error) + // AddPasskey add passkey to user. + AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error) + // RemovePasskey remove passkey from user. + RemovePasskey(ctx context.Context, uid int, keyId string) error + // MarkPasskeyUsed updates passkey used at. + MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error + // CountByTimeRange count users by time range. Will return all records if start or end is nil. + CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) + // ListUsers list users with pagination. + ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error) + // Upsert upserts a user. + Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error) + // Delete deletes a user. + Delete(ctx context.Context, uid int) error + // CalculateStorage calculate user's storage from scratch and update user's storage. 
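// The implementation pages through the user's files in fixed batches of 5000
// rows, sums the sizes of every attached entity (versions, thumbnails, live
// photos), and writes the total back to the user's storage column.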
+ CalculateStorage(ctx context.Context, uid int) (int64, error) + } + ListUserParameters struct { + *PaginationArgs + GroupID int + Status user.Status + Nick string + Email string + } + ListUserResult struct { + *PaginationResults + Users []*ent.User + } +) + +func NewUserClient(client *ent.Client) UserClient { + return &userClient{client: client} +} + +type userClient struct { + client *ent.Client +} + +type ( + // NewUserArgs args to create a new user + NewUserArgs struct { + Email string + Nick string // Optional + PlainPassword string + Status user.Status + GroupID int + Avatar string // Optional + Language string // Optional + } + CreateStoragePackArgs struct { + UserID int + Name string + Size int64 + ExpireAt time.Time + } +) + +func (c *userClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) { + if start == nil || end == nil { + return c.client.User.Query().Count(ctx) + } + return c.client.User.Query().Where(user.CreatedAtGTE(*start), user.CreatedAtLT(*end)).Count(ctx) +} + +func (c *userClient) UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error) { + return c.client.User.UpdateOne(u).SetNick(name).Save(ctx) +} + +func (c *userClient) UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error) { + return c.client.User.UpdateOne(u).SetAvatar(avatar).Save(ctx) +} + +func (c *userClient) UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error) { + if secret == "" { + return c.client.User.UpdateOne(u).ClearTwoFactorSecret().Save(ctx) + } + return c.client.User.UpdateOne(u).SetTwoFactorSecret(secret).Save(ctx) +} + +func (c *userClient) UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error) { + digest, err := digestPassword(newPassword) + if err != nil { + return nil, err + } + + return c.client.User.UpdateOne(u).SetPassword(digest).Save(ctx) +} + +func (c *userClient) SetClient(newClient *ent.Client) TxOperator { + return &userClient{client: newClient} +} + +func (c *userClient) GetClient() *ent.Client { + return c.client +} + +func (c *userClient) ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error) { + return c.client.Passkey.Query().Where(passkey.UserID(uid)).All(ctx) +} + +func (c *userClient) AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error) { + return c.client.Passkey.Create(). + SetName(name). + SetCredentialID(base64.StdEncoding.EncodeToString(credential.ID)). + SetUserID(uid). + SetCredential(credential). 
+ Save(ctx) +} + +func (c *userClient) RemovePasskey(ctx context.Context, uid int, keyId string) error { + ctx = schema.SkipSoftDelete(ctx) + _, err := c.client.Passkey.Delete().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).Exec(ctx) + return err +} + +func (c *userClient) MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error { + _, err := c.client.Passkey.Update().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).SetUsedAt(time.Now()).Save(ctx) + return err +} + +func (c *userClient) Delete(ctx context.Context, uid int) error { + // Dav accounts + if _, err := c.client.DavAccount.Delete().Where(davaccount.OwnerID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil { + return fmt.Errorf("failed to delete dav accounts: %w", err) + } + + // Passkeys + if _, err := c.client.Passkey.Delete().Where(passkey.UserID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil { + return fmt.Errorf("failed to delete passkeys: %w", err) + } + + // Tasks + if _, err := c.client.Task.Delete().Where(task.UserTasks(uid)).Exec(ctx); err != nil { + return fmt.Errorf("failed to delete tasks: %w", err) + } + + return c.client.User.DeleteOneID(uid).Exec(schema.SkipSoftDelete(ctx)) +} + +func (c *userClient) ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error { + ae := serializer.NewAggregateError() + for uid, diff := range diffs { + if err := c.client.User.Update().Where(user.ID(uid)).AddStorage(diff).Exec(ctx); err != nil { + ae.Add(fmt.Sprintf("%d", uid), fmt.Errorf("failed to apply storage diff for user %d: %w", uid, err)) + } + } + + return ae.Aggregate() +} + +func (c *userClient) CalculateStorage(ctx context.Context, uid int) (int64, error) { + var sum int64 + batchSize := 5000 + offset := 0 + + for { + allFiles, err := c.client.File.Query(). + Where(file.HasOwnerWith(user.ID(uid))). + WithEntities(). + Offset(offset). + Limit(batchSize). + All(ctx) + if err != nil { + return 0, fmt.Errorf("failed to list user files: %w", err) + } + + if len(allFiles) == 0 { + break + } + + for _, file := range allFiles { + for _, entity := range file.Edges.Entities { + sum += entity.Size + } + } + + offset += batchSize + } + + if _, err := c.client.User.UpdateOneID(uid).SetStorage(sum).Save(ctx); err != nil { + return 0, err + } + + return sum, nil +} + +func (c *userClient) SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error) { + return c.client.User.UpdateOne(u).SetStatus(status).Save(ctx) +} + +func (c *userClient) Create(ctx context.Context, args *NewUserArgs) (*ent.User, error) { + // Try to check if there's user with same email. + if existedUser, err := c.GetByEmail(ctx, args.Email); err == nil { + if existedUser.Status == user.StatusInactive { + return existedUser, ErrInactiveUserExisted + } + return existedUser, ErrUserEmailExisted + } + + nick := args.Nick + if nick == "" { + nick = strings.Split(args.Email, "@")[0] + } + + userSetting := &types.UserSetting{VersionRetention: true, VersionRetentionMax: 10} + query := c.client.User.Create(). + SetEmail(args.Email). + SetNick(nick). + SetStatus(args.Status). + SetGroupID(args.GroupID). + SetAvatar(args.Avatar) + + if args.PlainPassword != "" { + pwdDigest, err := digestPassword(args.PlainPassword) + if err != nil { + return nil, fmt.Errorf("failed to sha256 password: %w", err) + } + query.SetPassword(pwdDigest) + } + + if args.Language != "" { + userSetting.Language = args.Language + } + query.SetSettings(userSetting) + + // Create user + newUser, err := query. 
+ Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create user: %w", err) + } + + if newUser.ID == 1 { + // For the first user registered, elevate it to admin group. + if _, err := newUser.Update().SetGroupID(1).Save(ctx); err != nil { + return newUser, fmt.Errorf("failed to elevate user to admin: %w", err) + } + } + return newUser, nil +} + +func (c *userClient) GetByEmail(ctx context.Context, email string) (*ent.User, error) { + return withUserEagerLoading(ctx, c.client.User.Query().Where(user.EmailEqualFold(email))).First(ctx) +} + +func (c *userClient) GetByID(ctx context.Context, id int) (*ent.User, error) { + return withUserEagerLoading(ctx, c.client.User.Query().Where(user.ID(id))).First(ctx) +} + +func (c *userClient) GetActiveByID(ctx context.Context, id int) (*ent.User, error) { + return withUserEagerLoading( + ctx, + c.client.User.Query(). + Where(user.ID(id)). + Where(user.StatusEQ(user.StatusActive)), + ).First(ctx) +} + +func (c *userClient) GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error) { + ctx = context.WithValue(ctx, LoadUserGroup{}, true) + return withUserEagerLoading( + ctx, + c.client.User.Query(). + Where(user.EmailEqualFold(email)). + Where(user.StatusEQ(user.StatusActive)). + WithDavAccounts(func(q *ent.DavAccountQuery) { + q.Where(davaccount.Password(pwd)) + }), + ).First(ctx) +} + +func (c *userClient) GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error) { + ctx = context.WithValue(ctx, LoadUserGroup{}, true) + if uid > 0 { + expectedUser, err := c.GetActiveByID(ctx, uid) + if err == nil { + return expectedUser, nil + } + + return nil, fmt.Errorf("failed to get user by id: %w", err) + } + + anonymous, err := c.AnonymousUser(ctx) + if err != nil { + return nil, fmt.Errorf("failed to construct anonymous user: %w", err) + } + + return anonymous, nil +} + +func (c *userClient) SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error) { + ctx = context.WithValue(ctx, LoadUserGroup{}, true) + return withUserEagerLoading( + ctx, + c.client.User.Query(). + Where(user.Or(user.EmailContainsFold(keyword), user.NickContainsFold(keyword))). + Limit(limit), + ).All(ctx) +} + +func (c *userClient) SaveSettings(ctx context.Context, u *ent.User) error { + return c.client.User.UpdateOne(u).SetSettings(u.Settings).Exec(ctx) +} + +// UserFromContext get user from context +func UserFromContext(ctx context.Context) *ent.User { + u, _ := ctx.Value(UserCtx{}).(*ent.User) + return u +} + +// UserIDFromContext get user id from context. 
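// It prefers an explicit UserIDCtx value and otherwise falls back to the ID of
// the user stored under UserCtx, so it returns 0 (the anonymous marker used by
// IsAnonymousUser) when neither is present.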
+func UserIDFromContext(ctx context.Context) int { + uid, ok := ctx.Value(UserIDCtx{}).(int) + if !ok { + u := UserFromContext(ctx) + if u != nil { + uid = u.ID + } + } + + return uid +} + +func (c *userClient) AnonymousUser(ctx context.Context) (*ent.User, error) { + groupClient := NewGroupClient(c.client, "", nil) + anonymousGroup, err := groupClient.AnonymousGroup(ctx) + if err != nil { + return nil, fmt.Errorf("anyonymous group not found: %w", err) + } + + // TODO: save into cache + anonymous := &ent.User{ + Settings: &types.UserSetting{}, + } + anonymous.SetGroup(anonymousGroup) + return anonymous, nil +} + +func (c *userClient) ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error) { + query := c.client.User.Query() + if args.GroupID != 0 { + query = query.Where(user.GroupUsers(args.GroupID)) + } + if args.Status != "" { + query = query.Where(user.StatusEQ(args.Status)) + } + if args.Nick != "" { + query = query.Where(user.NickContainsFold(args.Nick)) + } + if args.Email != "" { + query = query.Where(user.EmailContainsFold(args.Email)) + } + query.Order(getUserOrderOption(args)...) + + // Count total items + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, err + } + + users, err := withUserEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx) + if err != nil { + return nil, err + } + + return &ListUserResult{ + PaginationResults: &PaginationResults{ + TotalItems: total, + Page: args.Page, + PageSize: args.PageSize, + }, + Users: users, + }, nil +} + +func (c *userClient) Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error) { + if u.ID == 0 { + q := c.client.User.Create(). + SetEmail(u.Email). + SetNick(u.Nick). + SetAvatar(u.Avatar). + SetStatus(u.Status). + SetGroupID(u.GroupUsers). + SetPassword(u.Password). + SetSettings(&types.UserSetting{}) + + if password != "" { + pwdDigest, err := digestPassword(password) + if err != nil { + return nil, fmt.Errorf("failed to sha256 password: %w", err) + } + q.SetPassword(pwdDigest) + } + + return q.Save(ctx) + } + + q := c.client.User.UpdateOne(u). + SetEmail(u.Email). + SetNick(u.Nick). + SetAvatar(u.Avatar). + SetStatus(u.Status). + SetGroupID(u.GroupUsers) + + if password != "" { + pwdDigest, err := digestPassword(password) + if err != nil { + return nil, fmt.Errorf("failed to sha256 password: %w", err) + } + q.SetPassword(pwdDigest) + } + + if twoFa != "" { + q.ClearTwoFactorSecret() + } + + return q.Save(ctx) +} + +func getUserOrderOption(args *ListUserParameters) []user.OrderOption { + orderTerm := getOrderTerm(args.Order) + switch args.OrderBy { + case user.FieldNick: + return []user.OrderOption{user.ByNick(orderTerm), user.ByID(orderTerm)} + case user.FieldStorage: + return []user.OrderOption{user.ByStorage(orderTerm), user.ByID(orderTerm)} + case user.FieldEmail: + return []user.OrderOption{user.ByEmail(orderTerm), user.ByID(orderTerm)} + case user.FieldUpdatedAt: + return []user.OrderOption{user.ByUpdatedAt(orderTerm), user.ByID(orderTerm)} + default: + return []user.OrderOption{user.ByID(orderTerm)} + } +} + +// IsAnonymousUser check if given user is anonymous user. 
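// Anonymous users are the in-memory instances built by AnonymousUser above and
// are never persisted, so a zero ID is a reliable marker.
//
// For reference, the password format handled by CheckPassword/digestPassword
// below is "<salt>:<hex digest>", with the digest being SHA-256(password+salt)
// for newly written passwords and SHA-1 for records carried over from V3 (the
// two are told apart by digest length); accounts imported from V2 keep the
// three-part form "md5:<hash>:<salt>".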
+func IsAnonymousUser(u *ent.User) bool { + return u.ID == 0 +} + +// CheckPassword 根据明文校验密码 +func CheckPassword(u *ent.User, password string) error { + // 根据存储密码拆分为 Salt 和 Digest + passwordStore := strings.Split(u.Password, ":") + if len(passwordStore) != 2 && len(passwordStore) != 3 { + return ErrorUnknownPasswordType + } + + // 兼容V2密码,升级后存储格式为: md5:$HASH:$SALT + if len(passwordStore) == 3 { + if passwordStore[0] != "md5" { + return ErrorUnknownPasswordType + } + hash := md5.New() + _, err := hash.Write([]byte(passwordStore[2] + password)) + bs := hex.EncodeToString(hash.Sum(nil)) + if err != nil { + return err + } + if bs != passwordStore[1] { + return ErrorIncorrectPassword + } + } + + //计算 Salt 和密码组合的SHA1摘要 + var hasher hash.Hash + if len(passwordStore[1]) == 64 { + hasher = sha256.New() + } else { + // Compatible with V3 + hasher = sha1.New() + } + + _, err := hasher.Write([]byte(password + passwordStore[0])) + bs := hex.EncodeToString(hasher.Sum(nil)) + if err != nil { + return err + } + + if bs != passwordStore[1] { + return ErrorIncorrectPassword + } + + return nil +} + +func withUserEagerLoading(ctx context.Context, q *ent.UserQuery) *ent.UserQuery { + if v, ok := ctx.Value(LoadUserGroup{}).(bool); ok && v { + q.WithGroup(func(gq *ent.GroupQuery) { + withGroupEagerLoading(ctx, gq) + }) + } + if v, ok := ctx.Value(LoadUserPasskey{}).(bool); ok && v { + q.WithPasskey() + } + return q +} + +func digestPassword(password string) (string, error) { + //生成16位 Salt + salt := util.RandStringRunes(16) + + //计算 Salt 和密码组合的SHA1摘要 + hash := sha256.New() + _, err := hash.Write([]byte(password + salt)) + bs := hex.EncodeToString(hash.Sum(nil)) + + if err != nil { + return "", err + } + + //存储 Salt 值和摘要, ":"分割 + return salt + ":" + string(bs), nil +} diff --git a/main.go b/main.go index 5e214e1a..931c5053 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,10 @@ package main import ( - "context" _ "embed" "flag" - "io/fs" - "net" - "net/http" - "os" - "os/signal" - "path/filepath" - "syscall" - "time" - - "github.com/cloudreve/Cloudreve/v3/bootstrap" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/routers" + "github.com/cloudreve/Cloudreve/v4/cmd" + "github.com/cloudreve/Cloudreve/v4/pkg/util" ) var ( @@ -27,132 +13,24 @@ var ( scriptName string ) -//go:embed assets.zip -var staticZip string - -var staticFS fs.FS - func init() { + flag.BoolVar(&util.UseWorkingDir, "use-working-dir", false, "Use working directory, instead of executable directory") flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "Path to the config file.") - flag.BoolVar(&isEject, "eject", false, "Eject all embedded static files.") flag.StringVar(&scriptName, "database-script", "", "Name of database util script.") - flag.Parse() + //flag.Parse() - staticFS = bootstrap.NewFS(staticZip) - bootstrap.Init(confPath, staticFS) + //staticFS = bootstrap.NewFS(staticZip) + //bootstrap.Init(confPath, staticFS) } func main() { + cmd.Execute() + return // 关闭数据库连接 - defer func() { - if model.DB != nil { - model.DB.Close() - } - }() - - if isEject { - // 开始导出内置静态资源文件 - bootstrap.Eject(staticFS) - return - } - - if scriptName != "" { - // 开始运行助手数据库脚本 - bootstrap.RunScript(scriptName) - return - } - - api := routers.InitRouter() - api.TrustedPlatform = conf.SystemConfig.ProxyHeader - server := &http.Server{Handler: api} - - // 收到信号后关闭服务器 - sigChan := make(chan 
os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go shutdown(sigChan, server) - - defer func() { - <-sigChan - }() - - // 如果启用了SSL - if conf.SSLConfig.CertPath != "" { - util.Log().Info("Listening to %q", conf.SSLConfig.Listen) - server.Addr = conf.SSLConfig.Listen - if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil { - util.Log().Error("Failed to listen to %q: %s", conf.SSLConfig.Listen, err) - return - } - } - - // 如果启用了Unix - if conf.UnixConfig.Listen != "" { - // delete socket file before listening - if _, err := os.Stat(conf.UnixConfig.Listen); err == nil { - if err = os.Remove(conf.UnixConfig.Listen); err != nil { - util.Log().Error("Failed to delete socket file: %s", err) - return - } - } - - util.Log().Info("Listening to %q", conf.UnixConfig.Listen) - if err := RunUnix(server); err != nil { - util.Log().Error("Failed to listen to %q: %s", conf.UnixConfig.Listen, err) - } - return - } - - util.Log().Info("Listening to %q", conf.SystemConfig.Listen) - server.Addr = conf.SystemConfig.Listen - if err := server.ListenAndServe(); err != nil { - util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err) - } -} - -func RunUnix(server *http.Server) error { - listener, err := net.Listen("unix", conf.UnixConfig.Listen) - if err != nil { - return err - } - - defer listener.Close() - defer os.Remove(conf.UnixConfig.Listen) - - if conf.UnixConfig.Perm > 0 { - err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(conf.UnixConfig.Perm)) - if err != nil { - util.Log().Warning( - "Failed to set permission to %q for socket file %q: %s", - conf.UnixConfig.Perm, - conf.UnixConfig.Listen, - err, - ) - } - } - - return server.Serve(listener) -} - -func shutdown(sigChan chan os.Signal, server *http.Server) { - sig := <-sigChan - util.Log().Info("Signal %s received, shutting down server...", sig) - ctx := context.Background() - if conf.SystemConfig.GracePeriod != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second) - defer cancel() - } - - // Shutdown http server - err := server.Shutdown(ctx) - if err != nil { - util.Log().Error("Failed to shutdown server: %s", err) - } - - // Persist in-memory cache - if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil { - util.Log().Warning("Failed to persist cache: %s", err) - } - close(sigChan) + //if scriptName != "" { + // // 开始运行助手数据库脚本 + // bootstrap.RunScript(scriptName) + // return + //} } diff --git a/middleware/auth.go b/middleware/auth.go index 3a7d7635..4df7ed7f 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,24 +1,21 @@ package middleware import ( - "bytes" - "context" - "crypto/md5" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/qiniu/go-sdk/v7/auth/qbox" - "io/ioutil" "net/http" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-contrib/sessions" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + 
"github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" ) @@ -31,14 +28,14 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc { return func(c *gin.Context) { var err error switch c.Request.Method { - case "PUT", "POST", "PATCH": - err = auth.CheckRequest(authInstance, c.Request) + case http.MethodPut, http.MethodPost, http.MethodPatch: + err = auth.CheckRequest(c, authInstance, c.Request) default: - err = auth.CheckURI(authInstance, c.Request.URL) + err = auth.CheckURI(c, authInstance, c.Request.URL) } if err != nil { - c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err)) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err)) c.Abort() return } @@ -50,29 +47,55 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc { // CurrentUser 获取登录用户 func CurrentUser() gin.HandlerFunc { return func(c *gin.Context) { - session := sessions.Default(c) - uid := session.Get("user_id") - if uid != nil { - user, err := model.GetActiveUserByID(uid) - if err == nil { - c.Set("user", &user) - } + dep := dependency.FromContext(c) + shouldContinue, err := dep.TokenAuth().VerifyAndRetrieveUser(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + if shouldContinue { + // TODO: Logto handler } + + uid := inventory.UserIDFromContext(c) + if err := SetUserCtx(c, uid); err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + c.Next() } } -// AuthRequired 需要登录 -func AuthRequired() gin.HandlerFunc { +// SetUserCtx set the current login user via uid +func SetUserCtx(c *gin.Context, uid int) error { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + loginUser, err := userClient.GetLoginUserByID(c, uid) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "failed to get login user", err) + } + + SetUserCtxByUser(c, loginUser) + return nil +} + +func SetUserCtxByUser(c *gin.Context, user *ent.User) { + util.WithValue(c, inventory.UserCtx{}, user) +} + +// LoginRequired 需要登录 +func LoginRequired() gin.HandlerFunc { return func(c *gin.Context) { - if user, _ := c.Get("user"); user != nil { - if _, ok := user.(*model.User); ok { - c.Next() - return - } + if u := inventory.UserFromContext(c); u != nil && !inventory.IsAnonymousUser(u) { + c.Next() + return } - c.JSON(200, serializer.CheckLogin()) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCheckLogin, "Login required", nil)) c.Abort() } } @@ -80,60 +103,84 @@ func AuthRequired() gin.HandlerFunc { // WebDAVAuth 验证WebDAV登录及权限 func WebDAVAuth() gin.HandlerFunc { return func(c *gin.Context) { - // OPTIONS 请求不需要鉴权,否则Windows10下无法保存文档 - if c.Request.Method == "OPTIONS" { - c.Next() - return - } - username, password, ok := c.Request.BasicAuth() if !ok { + // OPTIONS 请求不需要鉴权 + if c.Request.Method == http.MethodOptions { + c.Next() + return + } c.Writer.Header()["WWW-Authenticate"] = []string{`Basic realm="cloudreve"`} c.Status(http.StatusUnauthorized) c.Abort() return } - expectedUser, err := 
model.GetActiveUserByEmail(username) + dep := dependency.FromContext(c) + l := dep.Logger() + userClient := dep.UserClient() + expectedUser, err := userClient.GetActiveByDavAccount(c, username, password) if err != nil { + if username == "" { + if u, err := userClient.GetByEmail(c, username); err == nil { + // Try login with known user but incorrect password, record audit log + SetUserCtxByUser(c, u) + } + } + + l.Debug("WebDAVAuth: failed to get user %q with provided credential: %s", username, err) c.Status(http.StatusUnauthorized) c.Abort() return } - // 密码正确? - webdav, err := model.GetWebdavByPassword(password, expectedUser.ID) - if err != nil { + // Validate dav account + accounts, err := expectedUser.Edges.DavAccountsOrErr() + if err != nil || len(accounts) == 0 { + l.Debug("WebDAVAuth: failed to get user dav accounts %q with provided credential: %s", username, err) c.Status(http.StatusUnauthorized) c.Abort() return } // 用户组已启用WebDAV? - if !expectedUser.Group.WebDAVEnabled { + group, err := expectedUser.Edges.GroupOrErr() + if err != nil { + l.Debug("WebDAVAuth: user group not found: %s", err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + + if !group.Permissions.Enabled(int(types.GroupPermissionWebDAV)) { c.Status(http.StatusForbidden) + l.Debug("WebDAVAuth: user %q does not have WebDAV permission.", expectedUser.Email) c.Abort() return } - // 用户组已启用WebDAV代理? - if !expectedUser.Group.OptionsSerialized.WebDAVProxy { - webdav.UseProxy = false + // 检查是否只读 + if expectedUser.Edges.DavAccounts[0].Options.Enabled(int(types.DavAccountReadOnly)) { + switch c.Request.Method { + case http.MethodDelete, http.MethodPut, "MKCOL", "COPY", "MOVE", "LOCK", "UNLOCK": + c.Status(http.StatusForbidden) + c.Abort() + return + } } - c.Set("user", &expectedUser) - c.Set("webdav", webdav) + SetUserCtxByUser(c, expectedUser) c.Next() } } // 对上传会话进行验证 -func UseUploadSession(policyType string) gin.HandlerFunc { +func UseUploadSession(policyType types.PolicyType) gin.HandlerFunc { return func(c *gin.Context) { // 验证key并查找用户 - resp := uploadCallbackCheck(c, policyType) - if resp.Code != 0 { - c.JSON(CallbackFailedStatusCode, resp) + err := uploadCallbackCheck(c, policyType) + if err != nil { + c.JSON(CallbackFailedStatusCode, serializer.Err(c, err)) c.Abort() return } @@ -143,84 +190,65 @@ func UseUploadSession(policyType string) gin.HandlerFunc { } // uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户 -func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response { +func uploadCallbackCheck(c *gin.Context, policyType types.PolicyType) error { // 验证 Callback Key sessionID := c.Param("sessionID") if sessionID == "" { - return serializer.ParamErr("Session ID cannot be empty", nil) + return serializer.NewError(serializer.CodeParamErr, "Session ID cannot be empty", nil) } - callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID) + dep := dependency.FromContext(c) + callbackSessionRaw, exist := dep.KV().Get(manager.UploadSessionCachePrefix + sessionID) if !exist { - return serializer.Err(serializer.CodeUploadSessionExpired, "上传会话不存在或已过期", nil) + return serializer.NewError(serializer.CodeUploadSessionExpired, "Upload session does not exist or expired", nil) } - callbackSession := callbackSessionRaw.(serializer.UploadSession) - c.Set(filesystem.UploadSessionCtx, &callbackSession) - if callbackSession.Policy.Type != policyType { - return serializer.Err(serializer.CodePolicyNotAllowed, "", nil) + callbackSession := 
callbackSessionRaw.(fs.UploadSession) + c.Set(manager.UploadSessionCtx, &callbackSession) + if callbackSession.Policy.Type != string(policyType) { + return serializer.NewError(serializer.CodePolicyNotAllowed, "", nil) } - // 清理回调会话 - _ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix) - - // 查找用户 - user, err := model.GetActiveUserByID(callbackSession.UID) - if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "", err) + if err := SetUserCtx(c, callbackSession.UID); err != nil { + return err } - c.Set(filesystem.UserCtx, &user) - return serializer.Response{} + + return nil } // RemoteCallbackAuth 远程回调签名验证 func RemoteCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { // 验证签名 - session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)} - if err := auth.CheckRequest(authInstance, c.Request); err != nil { - c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err)) - c.Abort() - return - } - - c.Next() - - } -} - -// QiniuCallbackAuth 七牛回调签名验证 -func QiniuCallbackAuth() gin.HandlerFunc { - return func(c *gin.Context) { - session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 验证回调是否来自qiniu - mac := qbox.NewMac(session.Policy.AccessKey, session.Policy.SecretKey) - ok, err := mac.VerifyCallback(c.Request) - if err != nil { - util.Log().Debug("Failed to verify callback request: %s", err) - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."}) + session := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession) + if session.Policy.Edges.Node == nil { + c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, "Node not found", nil)) c.Abort() return } - if !ok { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Invalid signature."}) + authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.Edges.Node.SlaveKey)} + if err := auth.CheckRequest(c, authInstance, c.Request); err != nil { + c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err)) c.Abort() return } c.Next() + } } // OSSCallbackAuth 阿里云OSS回调签名验证 func OSSCallbackAuth() gin.HandlerFunc { return func(c *gin.Context) { - err := oss.VerifyCallbackSignature(c.Request) + dep := dependency.FromContext(c) + err := oss.VerifyCallbackSignature(c.Request, dep.KV(), dep.RequestClient( + request.WithContext(c), + request.WithLogger(logging.FromContext(c)), + )) if err != nil { - util.Log().Debug("Failed to verify callback request: %s", err) + dep.Logger().Debug("Failed to verify callback request: %s", err) c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."}) c.Abort() return @@ -230,71 +258,12 @@ func OSSCallbackAuth() gin.HandlerFunc { } } -// UpyunCallbackAuth 又拍云回调签名验证 -func UpyunCallbackAuth() gin.HandlerFunc { - return func(c *gin.Context) { - session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 获取请求正文 - body, err := ioutil.ReadAll(c.Request.Body) - c.Request.Body.Close() - if err != nil { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()}) - c.Abort() - return - } - - c.Request.Body = ioutil.NopCloser(bytes.NewReader(body)) - - // 准备验证Upyun回调签名 - handler := upyun.Driver{Policy: &session.Policy} - contentMD5 := c.Request.Header.Get("Content-Md5") - date := c.Request.Header.Get("Date") 
- actualSignature := c.Request.Header.Get("Authorization") - - // 计算正文MD5 - actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body)) - if actualContentMD5 != contentMD5 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5 mismatch."}) - c.Abort() - return - } - - // 计算理论签名 - signature := handler.Sign(context.Background(), []string{ - "POST", - c.Request.URL.Path, - date, - contentMD5, - }) - - // 对比签名 - if signature != actualSignature { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Signature not match"}) - c.Abort() - return - } - - c.Next() - } -} - -// OneDriveCallbackAuth OneDrive回调签名验证 -func OneDriveCallbackAuth() gin.HandlerFunc { - return func(c *gin.Context) { - // 发送回调结束信号 - mq.GlobalMQ.Publish(c.Param("sessionID"), mq.Message{}) - - c.Next() - } -} - // IsAdmin 必须为管理员用户组 func IsAdmin() gin.HandlerFunc { return func(c *gin.Context) { - user, _ := c.Get("user") - if user.(*model.User).Group.ID != 1 && user.(*model.User).ID != 1 { - c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "", nil)) + user := inventory.UserFromContext(c) + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) { + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeNoPermissionErr, "", nil)) c.Abort() return } diff --git a/middleware/auth_test.go b/middleware/auth_test.go deleted file mode 100644 index 9e8650fe..00000000 --- a/middleware/auth_test.go +++ /dev/null @@ -1,605 +0,0 @@ -package middleware - -import ( - "database/sql" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/qiniu/go-sdk/v7/auth/qbox" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestCurrentUser(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - - //session为空 - sessionFunc := Session("233") - sessionFunc(c) - CurrentUser()(c) - user, _ := c.Get("user") - asserts.Nil(user) - - //session正确 - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - util.SetSession(c, map[string]interface{}{"user_id": 1}) - rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}). 
- AddRow(1, nil, "admin@cloudreve.org", "{}") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows) - CurrentUser()(c) - user, _ = c.Get("user") - asserts.NotNil(user) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestAuthRequired(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - AuthRequiredFunc := AuthRequired() - - // 未登录 - AuthRequiredFunc(c) - asserts.NotNil(c) - - // 类型错误 - c.Set("user", 123) - AuthRequiredFunc(c) - asserts.NotNil(c) - - // 正常 - c.Set("user", &model.User{}) - AuthRequiredFunc(c) - asserts.NotNil(c) -} - -func TestSignRequired(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - SignRequiredFunc := SignRequired(authInstance) - - // 鉴权失败 - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("PUT", "/test", nil) - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.True(c.IsAborted()) - - // Sign verify success - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("PUT", "/test", nil) - c.Request = auth.SignRequest(authInstance, c.Request, 0) - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.False(c.IsAborted()) -} - -func TestWebDAVAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := WebDAVAuth() - - // options请求跳过验证 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("OPTIONS", "/test", nil) - AuthFunc(c) - } - - // 请求HTTP Basic Auth - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - AuthFunc(c) - asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"]) - } - - // 用户名不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "password", "email"}), - ) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) - } - - // 密码错误 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"), - ) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) - } - - //未启用 WebDAV - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "password", "email", "group_id", "options"}). 
- AddRow(1, - "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", - "who@cloudreve.org", - 1, - "{}", - ), - ) - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false)) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusForbidden) - } - - //正常 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "password", "email", "group_id", "options"}). - AddRow(1, - "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", - "who@cloudreve.org", - 1, - "{}", - ), - ) - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true)) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), 200) - _, ok := c.Get("user") - asserts.True(ok) - } - -} - -func TestUseUploadSession(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := UseUploadSession("local") - - // sessionID 为空 - { - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - cache.Set( - filesystem.UploadSessionCachePrefix+"testCallBackRemote", - serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{Type: "local"}, - }, - 0, - ) - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]")) - mock.ExpectQuery("SELECT(.+)policies(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testCallBackRemote"}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } -} - -func TestUploadCallbackCheck(t *testing.T) { - a := assert.New(t) - rec := httptest.NewRecorder() - - // 上传会话不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testSessionNotExist"}, - } - res := uploadCallbackCheck(c, "local") - a.Contains("上传会话不存在或已过期", res.Msg) - } - - // 上传策略不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testPolicyNotMatch"}, - } - cache.Set( - filesystem.UploadSessionCachePrefix+"testPolicyNotMatch", - serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{Type: "remote"}, - }, - 0, - ) - res := uploadCallbackCheck(c, "local") - a.Contains("Policy not supported", res.Msg) - } - - // 用户不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testUserNotExist"}, - } - cache.Set( - filesystem.UploadSessionCachePrefix+"testUserNotExist", - serializer.UploadSession{ - UID: 313, - VirtualPath: "/", - Policy: model.Policy{Type: "remote"}, - }, - 0, - ) - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"})) - res := uploadCallbackCheck(c, "remote") - a.Contains("找不到用户", res.Msg) - a.NoError(mock.ExpectationsWereMet()) - _, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist") - a.False(ok) - } -} - -func TestRemoteCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := RemoteCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{SecretKey: "123"}, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.False(c.IsAborted()) - } - - // 签名错误 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{SecretKey: "123"}, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - AuthFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestQiniuCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := QiniuCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil) - mac := qbox.NewMac("123", "123") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } - - // 验证失败 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, 
&serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil) - mac := qbox.NewMac("123", "1213") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestOSSCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := OSSCallbackAuth() - - // 签名验证失败 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil) - mac := qbox.NewMac("123", "123") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`))) - c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="} - c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="} - AuthFunc(c) - asserts.False(c.IsAborted()) - } - -} - -type fakeRead string - -func (r fakeRead) Read(p []byte) (int, error) { - return 0, errors.New("error") -} - -func TestUpyunCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := UpyunCallbackAuth() - - // 无法获取请求正文 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead(""))) - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 正文MD5不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"123"} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 签名不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - c, _ := 
gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"} - c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="} - AuthFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestOneDriveCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := OneDriveCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "TestOneDriveCallbackAuth"}, - } - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1"))) - res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1) - AuthFunc(c) - select { - case <-res: - case <-time.After(time.Millisecond * 500): - asserts.Fail("mq message should be published") - } - asserts.False(c.IsAborted()) - } -} - -func TestIsAdmin(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := IsAdmin() - - // 非管理员 - { - c, _ := gin.CreateTestContext(rec) - c.Set("user", &model.User{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 是管理员 - { - c, _ := gin.CreateTestContext(rec) - user := &model.User{} - user.Group.ID = 1 - c.Set("user", user) - testFunc(c) - asserts.False(c.IsAborted()) - } - - // 初始用户,非管理组 - { - c, _ := gin.CreateTestContext(rec) - user := &model.User{} - user.Group.ID = 2 - user.ID = 1 - c.Set("user", user) - testFunc(c) - asserts.False(c.IsAborted()) - } -} diff --git a/middleware/captcha.go b/middleware/captcha.go index baf24a50..b97c9ba5 100644 --- a/middleware/captcha.go +++ b/middleware/captcha.go @@ -3,52 +3,56 @@ package middleware import ( "bytes" "encoding/json" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/recaptcha" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/recaptcha" + request2 "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" "github.com/gin-gonic/gin" "github.com/mojocn/base64Captcha" - captcha "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha/v20190722" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" "io" - "io/ioutil" - "strconv" + "net/http" + "net/url" + "strings" "time" ) type req struct { - CaptchaCode string `json:"captchaCode"` - Ticket string `json:"ticket"` - Randstr string `json:"randstr"` + Captcha string `json:"captcha"` + Ticket string `json:"ticket"` + Randstr string `json:"randstr"` } const ( captchaNotMatch = "CAPTCHA not match." captchaRefresh = "Verification failed, please refresh the page and retry." 
+ + tcCaptchaEndpoint = "captcha.tencentcloudapi.com" + turnstileEndpoint = "https://challenges.cloudflare.com/turnstile/v0/siteverify" +) + +// CaptchaIDCtx defines keys for captcha ID +type ( + CaptchaIDCtx struct{} + turnstileResponse struct { + Success bool `json:"success"` + } ) // CaptchaRequired 验证请求签名 -func CaptchaRequired(configName string) gin.HandlerFunc { +func CaptchaRequired(enabled func(c *gin.Context) bool) gin.HandlerFunc { return func(c *gin.Context) { - // 相关设定 - options := model.GetSettingByNames(configName, - "captcha_type", - "captcha_ReCaptchaSecret", - "captcha_TCaptcha_SecretId", - "captcha_TCaptcha_SecretKey", - "captcha_TCaptcha_CaptchaAppId", - "captcha_TCaptcha_AppSecretKey") - // 检查验证码 - isCaptchaRequired := model.IsTrueVal(options[configName]) + if enabled(c) { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + l := logging.FromContext(c) - if isCaptchaRequired { var service req bodyCopy := new(bytes.Buffer) _, err := io.Copy(bodyCopy, c.Request.Body) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err)) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err)) c.Abort() return } @@ -56,65 +60,69 @@ func CaptchaRequired(configName string) gin.HandlerFunc { bodyData := bodyCopy.Bytes() err = json.Unmarshal(bodyData, &service) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err)) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err)) c.Abort() return } - c.Request.Body = ioutil.NopCloser(bytes.NewReader(bodyData)) - switch options["captcha_type"] { - case "normal": - captchaID := util.GetSession(c, "captchaID") - util.DeleteSession(c, "captchaID") - if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) { - c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err)) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyData)) + switch settings.CaptchaType(c) { + case setting.CaptchaNormal, setting.CaptchaTcaptcha: + if service.Ticket == "" || !base64Captcha.VerifyCaptcha(service.Ticket, service.Captcha) { + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err)) c.Abort() return } break - case "recaptcha": - reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second) + case setting.CaptchaReCaptcha: + captchaSetting := settings.ReCaptcha(c) + reCAPTCHA, err := recaptcha.NewReCAPTCHA(captchaSetting.Secret, recaptcha.V2, 10*time.Second) if err != nil { - util.Log().Warning("reCAPTCHA verification failed, %s", err) + l.Warning("reCAPTCHA verification failed, %s", err) c.Abort() break } - err = reCAPTCHA.Verify(service.CaptchaCode) + err = reCAPTCHA.Verify(service.Captcha) if err != nil { - util.Log().Warning("reCAPTCHA verification failed, %s", err) - c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil)) + l.Warning("reCAPTCHA verification failed, %s", err) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaRefresh, err)) c.Abort() return } break - case "tcaptcha": - credential := common.NewCredential( - options["captcha_TCaptcha_SecretId"], - options["captcha_TCaptcha_SecretKey"], + case setting.CaptchaTurnstile: + captchaSetting := settings.TurnstileCaptcha(c) + r := dep.RequestClient( + request2.WithContext(c), + request2.WithLogger(logging.FromContext(c)), + 
request2.WithHeader(http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}), ) - cpf := profile.NewClientProfile() - cpf.HttpProfile.Endpoint = "captcha.tencentcloudapi.com" - client, _ := captcha.NewClient(credential, "", cpf) - request := captcha.NewDescribeCaptchaResultRequest() - request.CaptchaType = common.Uint64Ptr(9) - appid, _ := strconv.Atoi(options["captcha_TCaptcha_CaptchaAppId"]) - request.CaptchaAppId = common.Uint64Ptr(uint64(appid)) - request.AppSecretKey = common.StringPtr(options["captcha_TCaptcha_AppSecretKey"]) - request.Ticket = common.StringPtr(service.Ticket) - request.Randstr = common.StringPtr(service.Randstr) - request.UserIp = common.StringPtr(c.ClientIP()) - response, err := client.DescribeCaptchaResult(request) + formData := url.Values{} + formData.Set("secret", captchaSetting.Secret) + formData.Set("response", service.Ticket) + res, err := r.Request("POST", turnstileEndpoint, strings.NewReader(formData.Encode())). + CheckHTTPResponse(http.StatusOK). + GetResponse() if err != nil { - util.Log().Warning("TCaptcha verification failed, %s", err) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err)) c.Abort() - break + return + } + + var trunstileRes turnstileResponse + err = json.Unmarshal([]byte(res), &trunstileRes) + if err != nil { + l.Warning("Turnstile verification failed, %s", err) + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err)) + c.Abort() + return } - if *response.Response.CaptchaCode != int64(1) { - c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil)) + if !trunstileRes.Success { + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err)) c.Abort() return } diff --git a/middleware/captcha_test.go b/middleware/captcha_test.go deleted file mode 100644 index 1846d31c..00000000 --- a/middleware/captcha_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package middleware - -import ( - "bytes" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "testing" -) - -type errReader int - -func (errReader) Read(p []byte) (n int, err error) { - return 0, errors.New("test error") -} - -func TestCaptchaRequired_General(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 未启用验证码 - { - cache.SetSettings(map[string]string{ - "login_captcha": "0", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // body 无法读取 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", errReader(1)) - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // body JSON 解析失败 - { - 
cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("123")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Normal(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 验证码错误 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "normal", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - Session("233")(c) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Recaptcha(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 无法初始化reCaptcha实例 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "recaptcha", - "captcha_ReCaptchaSecret": "", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // 验证码错误 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "recaptcha", - "captcha_ReCaptchaSecret": "233", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Tcaptcha(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 验证出错 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "tcaptcha", - "captcha_ReCaptchaSecret": "", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} diff --git a/middleware/cluster.go b/middleware/cluster.go index 2e814bea..3557f440 100644 --- a/middleware/cluster.go +++ b/middleware/cluster.go @@ -1,62 +1,75 @@ package middleware import ( - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "sync" + + 
"github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/routers/controllers" "github.com/gin-gonic/gin" - "strconv" ) -// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据 -func MasterMetadata() gin.HandlerFunc { - return func(c *gin.Context) { - c.Set("MasterSiteID", c.GetHeader(auth.CrHeaderPrefix+"Site-Id")) - c.Set("MasterSiteURL", c.GetHeader(auth.CrHeaderPrefix+"Site-Url")) - c.Set("MasterVersion", c.GetHeader(auth.CrHeaderPrefix+"Cloudreve-Version")) - c.Next() - } +type SlaveNodeSettingGetter interface { + // GetNodeSetting returns the node settings and its hash + GetNodeSetting() (*types.NodeSetting, string) } -// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例 -func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc { +var downloaderPool = sync.Map{} + +// PrepareSlaveDownloader creates or resume a downloader based on input node settings +func PrepareSlaveDownloader(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc { return func(c *gin.Context) { - if siteID, exist := c.Get("MasterSiteID"); exist { - // 获取对应主机节点的从机Aria2实例 - caller, err := clusterController.GetAria2Instance(siteID.(string)) - if err != nil { - c.JSON(200, serializer.Err(serializer.CodeNotSet, "Failed to get Aria2 instance", err)) - c.Abort() - return - } - - c.Set("MasterAria2Instance", caller) + nodeSettings, hash := controllers.ParametersFromContext[SlaveNodeSettingGetter](c, ctxKey).GetNodeSetting() + + // try to get downloader from pool + if d, ok := downloaderPool.Load(hash); ok { + c.Set(downloader.DownloaderCtxKey, d) c.Next() return } - c.JSON(200, serializer.ParamErr("Unknown master node ID", nil)) - c.Abort() + // create a new downloader + d, err := cluster.NewDownloader(c, dep.RequestClient(request.WithContext(c), request.WithLogger(dep.Logger())), dep.SettingProvider(), nodeSettings) + if err != nil { + c.JSON(200, serializer.ParamErr(c, "Failed to create downloader", err)) + c.Abort() + return + } + + // save downloader to pool + downloaderPool.Store(hash, d) + c.Set(downloader.DownloaderCtxKey, d) + c.Next() } } -func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc { +func SlaveRPCSignRequired() gin.HandlerFunc { return func(c *gin.Context) { - nodeID, err := strconv.ParseUint(c.GetHeader(auth.CrHeaderPrefix+"Node-Id"), 10, 64) - if err != nil { - c.JSON(200, serializer.ParamErr("Unknown master node ID", err)) + nodeId := cluster.NodeIdFromContext(c) + if nodeId == 0 { + c.JSON(200, serializer.ParamErr(c, "Unknown node ID", nil)) c.Abort() return } - slaveNode := nodePool.GetNodeByID(uint(nodeID)) - if slaveNode == nil { - c.JSON(200, serializer.ParamErr("Unknown master node ID", err)) + np, err := dependency.FromContext(c).NodePool(c) + if err != nil { + c.JSON(200, serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err)) c.Abort() return } - SignRequired(slaveNode.MasterAuthInstance())(c) + slaveNode, err := np.Get(c, types.NodeCapabilityNone, nodeId) + if slaveNode == nil || slaveNode.IsMaster() { + c.JSON(200, serializer.ParamErr(c, "Unknown node ID", err)) + c.Abort() + return + } + SignRequired(slaveNode.AuthInstance())(c) } } diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go deleted file mode 100644 index 440163d7..00000000 
--- a/middleware/cluster_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package middleware - -import ( - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestMasterMetadata(t *testing.T) { - a := assert.New(t) - masterMetaDataFunc := MasterMetadata() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - - c.Request.Header = map[string][]string{ - "X-Cr-Site-Id": {"expectedSiteID"}, - "X-Cr-Site-Url": {"expectedSiteURL"}, - "X-Cr-Cloudreve-Version": {"expectedMasterVersion"}, - } - masterMetaDataFunc(c) - siteID, _ := c.Get("MasterSiteID") - siteURL, _ := c.Get("MasterSiteURL") - siteVersion, _ := c.Get("MasterVersion") - - a.Equal("expectedSiteID", siteID.(string)) - a.Equal("expectedSiteURL", siteURL.(string)) - a.Equal("expectedMasterVersion", siteVersion.(string)) -} - -func TestSlaveRPCSignRequired(t *testing.T) { - a := assert.New(t) - np := &cluster.NodePool{} - np.Init() - slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np) - rec := httptest.NewRecorder() - - // id parse failed - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "unknown") - slaveRPCSignRequiredFunc(c) - a.True(c.IsAborted()) - } - - // node id not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "38") - slaveRPCSignRequiredFunc(c) - a.True(c.IsAborted()) - } - - // success - { - authInstance := auth.HMACAuth{SecretKey: []byte("")} - np.Add(&model.Node{Model: gorm.Model{ - ID: 38, - }}) - - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("POST", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "38") - c.Request = auth.SignRequest(authInstance, c.Request, 0) - slaveRPCSignRequiredFunc(c) - a.False(c.IsAborted()) - } -} - -func TestUseSlaveAria2Instance(t *testing.T) { - a := assert.New(t) - - // MasterSiteID not set - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - useSlaveAria2InstanceFunc(c) - a.True(c.IsAborted()) - } - - // Cannot get aria2 instances - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("MasterSiteID", "expectedSiteID") - testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error")) - useSlaveAria2InstanceFunc(c) - a.True(c.IsAborted()) - testController.AssertExpectations(t) - } - - // Success - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("MasterSiteID", "expectedSiteID") - testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil) - 
useSlaveAria2InstanceFunc(c) - a.False(c.IsAborted()) - res, _ := c.Get("MasterAria2Instance") - a.NotNil(res) - testController.AssertExpectations(t) - } -} diff --git a/middleware/common.go b/middleware/common.go index 9b2cb08d..21833bfd 100644 --- a/middleware/common.go +++ b/middleware/common.go @@ -1,26 +1,35 @@ package middleware import ( + "context" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/auth/requestinfo" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" "net/http" + "time" ) // HashID 将给定对象的HashID转换为真实ID func HashID(IDType int) gin.HandlerFunc { return func(c *gin.Context) { + dep := dependency.FromContext(c) if c.Param("id") != "" { - id, err := hashid.DecodeHashID(c.Param("id"), IDType) + id, err := dep.HashIDEncoder().Decode(c.Param("id"), IDType) if err == nil { - c.Set("object_id", id) + util.WithValue(c, hashid.ObjectIDCtx{}, id) c.Next() return } - c.JSON(200, serializer.ParamErr("Failed to parse object ID", nil)) + c.JSON(200, serializer.ParamErr(c, "Failed to parse object ID", err)) c.Abort() return @@ -30,10 +39,10 @@ func HashID(IDType int) gin.HandlerFunc { } // IsFunctionEnabled 当功能未开启时阻止访问 -func IsFunctionEnabled(key string) gin.HandlerFunc { +func IsFunctionEnabled(check func(c *gin.Context) bool) gin.HandlerFunc { return func(c *gin.Context) { - if !model.IsTrueVal(model.GetSettingByName(key)) { - c.JSON(200, serializer.Err(serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil)) + if !check(c) { + c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil)) c.Abort() return } @@ -56,9 +65,10 @@ func Sandbox() gin.HandlerFunc { } // StaticResourceCache 使用静态资源缓存策略 -func StaticResourceCache() gin.HandlerFunc { +func StaticResourceCache(dep dependency.Dep) gin.HandlerFunc { + settings := dep.SettingProvider() return func(c *gin.Context) { - c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400))) + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", settings.PublicResourceMaxAge(c))) } } @@ -66,8 +76,9 @@ func StaticResourceCache() gin.HandlerFunc { // MobileRequestOnly func MobileRequestOnly() gin.HandlerFunc { return func(c *gin.Context) { - if c.GetHeader(auth.CrHeaderPrefix+"ios") == "" { - c.Redirect(http.StatusMovedPermanently, model.GetSiteURL().String()) + dep := dependency.FromContext(c) + if c.GetHeader(constants.CrHeaderPrefix+"ios") == "" { + c.Redirect(http.StatusMovedPermanently, dep.SettingProvider().SiteURL(c).String()) c.Abort() return } @@ -75,3 +86,66 @@ func MobileRequestOnly() gin.HandlerFunc { c.Next() } } + +// InitializeHandling is added at the beginning of handler chain, it did following setups: +// 1. Inject dependency manager into request context +// 2. Generate and inject correlation ID for diagnostic. 
+func InitializeHandling(dep dependency.Dep) gin.HandlerFunc { + return func(c *gin.Context) { + reqInfo := &requestinfo.RequestInfo{ + IP: c.ClientIP(), + Host: c.Request.Host, + UserAgent: c.Request.UserAgent(), + } + cid := uuid.FromStringOrNil(c.GetHeader(request.CorrelationHeader)) + if cid == uuid.Nil { + cid = uuid.Must(uuid.NewV4()) + } + + l := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s]", cid)) + ctx := dep.ForkWithLogger(c.Request.Context(), l) + ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, cid) + ctx = context.WithValue(ctx, requestinfo.RequestInfoCtx{}, reqInfo) + ctx = context.WithValue(ctx, logging.LoggerCtx{}, l) + if id := c.Param("nodeId"); id != "" { + ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, id) + } else { + ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, c.GetHeader(request.SlaveNodeIDHeader)) + } + c.Request = c.Request.WithContext(ctx) + + c.Next() + } +} + +// InitializeHandlingSlave retrieves the correlation ID and other metadata from the request header +func InitializeHandlingSlave() gin.HandlerFunc { + return func(c *gin.Context) { + ctx := context.WithValue(c.Request.Context(), cluster.MasterSiteIDCtx{}, c.GetHeader(request.SiteIDHeader)) + ctx = context.WithValue(ctx, cluster.MasterSiteUrlCtx{}, c.GetHeader(request.SiteURLHeader)) + ctx = context.WithValue(ctx, cluster.MasterSiteVersionCtx{}, c.GetHeader(request.SiteVersionHeader)) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} + +// Logging logs incoming request info +func Logging() gin.HandlerFunc { + return func(c *gin.Context) { + // Start timer + start := time.Now() + path := c.Request.URL.Path + raw := c.Request.URL.RawQuery + + // Process request + c.Next() + + if raw != "" { + path = path + "?" + raw + } + + l := logging.FromContext(c) + logging.Request(l, true, c.Writer.Status(), c.Request.Method, c.ClientIP(), path, + c.Errors.ByType(gin.ErrorTypePrivate).String(), start) + } +} diff --git a/middleware/common_test.go b/middleware/common_test.go deleted file mode 100644 index 1ab839a8..00000000 --- a/middleware/common_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestHashID(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - TestFunc := HashID(hashid.FolderID) - - // No ID given, skip - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } - - // ID given, parsing fails - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "2333"}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(c.IsAborted()) - } - - // ID given, parsed successfully - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", hashid.HashID(1, hashid.FolderID)}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } -} - -func TestIsFunctionEnabled(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - TestFunc := IsFunctionEnabled("TestIsFunctionEnabled") - - // Not enabled - { - 
cache.Set("setting_TestIsFunctionEnabled", "0", 0) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.True(c.IsAborted()) - } - // 开启 - { - cache.Set("setting_TestIsFunctionEnabled", "1", 0) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - -} - -func TestCacheControl(t *testing.T) { - a := assert.New(t) - TestFunc := CacheControl() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache") -} - -func TestSandbox(t *testing.T) { - a := assert.New(t) - TestFunc := Sandbox() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox") -} - -func TestStaticResourceCache(t *testing.T) { - a := assert.New(t) - TestFunc := StaticResourceCache() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age") -} diff --git a/middleware/file.go b/middleware/file.go index 995637e6..423ee3c3 100644 --- a/middleware/file.go +++ b/middleware/file.go @@ -1,30 +1,49 @@ package middleware import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/cloudreve/Cloudreve/v4/routers/controllers" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" ) -// ValidateSourceLink validates if the perm source link is a valid redirect link -func ValidateSourceLink() gin.HandlerFunc { +// UrisService is a wrapper for service supports batch file operations +type UrisService interface { + GetUris() []string +} + +// ValidateBatchFileCount validates if the batch file count is within the limit +func ValidateBatchFileCount(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc { + settings := dep.SettingProvider() return func(c *gin.Context) { - linkID, ok := c.Get("object_id") - if !ok { - c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) + uris := controllers.ParametersFromContext[UrisService](c, ctxKey) + limit := settings.MaxBatchedFile(c) + if len((uris).GetUris()) > limit { + c.JSON(200, serializer.ErrWithDetails( + c, + serializer.CodeTooManyUris, + fmt.Sprintf("Maximum allowed batch size: %d", limit), + nil, + )) c.Abort() return } - sourceLink, err := model.GetSourceLinkByID(linkID) - if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") { - c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) - c.Abort() - return + c.Next() + } +} + +// ContextHint parses the context hint header and set it to context +func ContextHint() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader(dbfs.ContextHintHeader) != "" { + util.WithValue(c, dbfs.ContextHintCtxKey{}, uuid.FromStringOrNil(c.GetHeader(dbfs.ContextHintHeader))) } - sourceLink.Downloaded() - c.Set("source_link", sourceLink) c.Next() } } diff --git a/middleware/file_test.go b/middleware/file_test.go deleted file mode 100644 index 5ca4014a..00000000 --- a/middleware/file_test.go +++ /dev/null @@ 
-1,57 +0,0 @@ -package middleware - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestValidateSourceLink(t *testing.T) { - a := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ValidateSourceLink() - - // ID 不存在 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - a.True(c.IsAborted()) - } - - // SourceLink 不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - testFunc(c) - a.True(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - - // 原文件不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"})) - testFunc(c) - a.True(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1)) - testFunc(c) - a.False(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - -} diff --git a/middleware/frontend.go b/middleware/frontend.go index f07d9b66..b1f10d69 100644 --- a/middleware/frontend.go +++ b/middleware/frontend.go @@ -1,73 +1,82 @@ package middleware import ( - "github.com/cloudreve/Cloudreve/v3/bootstrap" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" - "io/ioutil" + "io" "net/http" "strings" ) // FrontendFileHandler 前端静态文件处理 -func FrontendFileHandler() gin.HandlerFunc { +func FrontendFileHandler(dep dependency.Dep) gin.HandlerFunc { + fs := dep.ServerStaticFS() + l := dep.Logger() + ignoreFunc := func(c *gin.Context) { c.Next() } - if bootstrap.StaticFS == nil { + if fs == nil { return ignoreFunc } // 读取index.html - file, err := bootstrap.StaticFS.Open("/index.html") + file, err := fs.Open("/index.html") if err != nil { - util.Log().Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.") + l.Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.") return ignoreFunc } - fileContentBytes, err := ioutil.ReadAll(file) + fileContentBytes, err := io.ReadAll(file) if err != nil { - util.Log().Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.") + l.Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.") return ignoreFunc } fileContent := string(fileContentBytes) - fileServer := http.FileServer(bootstrap.StaticFS) + fileServer := http.FileServer(fs) return func(c *gin.Context) { path := c.Request.URL.Path - // API 跳过 + // Skipping routers handled by backend if strings.HasPrefix(path, "/api") || - strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || - strings.HasPrefix(path, "/f") || + strings.HasPrefix(path, "/f/") || + 
strings.HasPrefix(path, "/s/") || path == "/manifest.json" { c.Next() return } // 不存在的路径和index.html均返回index.html - if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) { + if (path == "/index.html") || (path == "/") || !fs.Exists("/", path) { // 读取、替换站点设置 - options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript", - "pwa_small_icon") + settingClient := dep.SettingProvider() + siteBasic := settingClient.SiteBasic(c) + pwaOpts := settingClient.PWA(c) + theme := settingClient.Theme(c) finalHTML := util.Replace(map[string]string{ - "{siteName}": options["siteName"], - "{siteDes}": options["siteDes"], - "{siteScript}": options["siteScript"], - "{pwa_small_icon}": options["pwa_small_icon"], + "{siteName}": siteBasic.Name, + "{siteDes}": siteBasic.Description, + "{siteScript}": siteBasic.Script, + "{pwa_small_icon}": pwaOpts.SmallIcon, + "{pwa_medium_icon}": pwaOpts.MediumIcon, + "var(--defaultThemeColor)": theme.DefaultTheme, }, fileContent) c.Header("Content-Type", "text/html") + c.Header("Cache-Control", "public, no-cache") c.String(200, finalHTML) c.Abort() return } - if path == "/service-worker.js" { + if path == "/sw.js" || strings.HasPrefix(path, "/locales/") { c.Header("Cache-Control", "public, no-cache") + } else if strings.HasPrefix(path, "/assets/") { + c.Header("Cache-Control", "public, max-age=31536000") } // 存在的静态文件 diff --git a/middleware/frontend_test.go b/middleware/frontend_test.go deleted file mode 100644 index d32529db..00000000 --- a/middleware/frontend_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package middleware - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/bootstrap" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -type StaticMock struct { - testMock.Mock -} - -func (m StaticMock) Open(name string) (http.File, error) { - args := m.Called(name) - return args.Get(0).(http.File), args.Error(1) -} - -func (m StaticMock) Exists(prefix string, filepath string) bool { - args := m.Called(prefix, filepath) - return args.Bool(0) -} - -func TestFrontendFileHandler(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 静态资源未加载 - { - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // index.html 不存在 - { - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(&os.File{}, errors.New("error")) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // index.html 读取失败 - { - file, _ := util.CreatNestedFile("tests/index.html") - file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). 
- Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // 成功且命中 - { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - - cache.Set("setting_siteName", "cloudreve", 0) - cache.Set("setting_siteKeywords", "cloudreve", 0) - cache.Set("setting_siteScript", "cloudreve", 0) - cache.Set("setting_pwa_small_icon", "cloudreve", 0) - - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功且命中静态文件 - { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - testStatic.On("Exists", "/", "/2"). - Return(true) - testStatic.On("Open", "/2"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/2", nil) - - TestFunc(c) - asserts.True(c.IsAborted()) - testStatic.AssertExpectations(t) - } - - // API 相关跳过 - { - for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", reqPath, nil) - - TestFunc(c) - asserts.False(c.IsAborted()) - } - } - -} diff --git a/middleware/mock.go b/middleware/mock.go index d026e77a..b0733801 100644 --- a/middleware/mock.go +++ b/middleware/mock.go @@ -1,7 +1,7 @@ package middleware import ( - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" ) diff --git a/middleware/mock_test.go b/middleware/mock_test.go deleted file mode 100644 index 1ebee20d..00000000 --- a/middleware/mock_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestMockHelper(t *testing.T) { - asserts := assert.New(t) - MockHelperFunc := MockHelper() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - - // 写入session - { - SessionMock["test"] = "pass" - Session("test")(c) - MockHelperFunc(c) - asserts.Equal("pass", util.GetSession(c, "test").(string)) - } - - // 写入context - { - ContextMock["test"] = "pass" - MockHelperFunc(c) - test, exist := c.Get("test") - asserts.True(exist) - asserts.Equal("pass", test.(string)) - - } -} diff --git a/middleware/session.go b/middleware/session.go index db90755b..aa1fae7d 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -1,14 +1,14 @@ package middleware import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + 
"github.com/cloudreve/Cloudreve/v4/pkg/sessionstore" "net/http" "strings" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" ) @@ -16,11 +16,11 @@ import ( // Store session存储 var Store sessions.Store -// Session 初始化session -func Session(secret string) gin.HandlerFunc { - // Redis设置不为空,且非测试模式时使用Redis - Store = sessionstore.NewStore(cache.Store, []byte(secret)) +const SessionName = "cloudreve-session" +// Session 初始化session +func Session(dep dependency.Dep) gin.HandlerFunc { + Store = sessionstore.NewStore(dep.KV(), []byte(dep.ConfigProvider().System().SessionSecret)) sameSiteMode := http.SameSiteDefaultMode switch strings.ToLower(conf.CORSConfig.SameSite) { case "default": @@ -42,7 +42,7 @@ func Session(secret string) gin.HandlerFunc { Secure: conf.CORSConfig.Secure, }) - return sessions.Sessions("cloudreve-session", Store) + return sessions.Sessions(SessionName, Store) } // CSRFInit 初始化CSRF标记 @@ -61,7 +61,7 @@ func CSRFCheck() gin.HandlerFunc { return } - c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "Invalid origin", nil)) + c.JSON(200, serializer.ErrDeprecated(serializer.CodeNoPermissionErr, "Invalid origin", nil)) c.Abort() } } diff --git a/middleware/session_test.go b/middleware/session_test.go deleted file mode 100644 index 9fbe0d21..00000000 --- a/middleware/session_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestSession(t *testing.T) { - asserts := assert.New(t) - - { - handler := Session("2333") - asserts.NotNil(handler) - asserts.NotNil(Store) - asserts.IsType(emptyFunc(), handler) - } -} - -func emptyFunc() gin.HandlerFunc { - return func(c *gin.Context) {} -} - -func TestCSRFInit(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - sessionFunc := Session("233") - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFInit()(c) - asserts.True(util.GetSession(c, "CSRF").(bool)) - } -} - -func TestCSRFCheck(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - sessionFunc := Session("233") - - // 通过检查 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFInit()(c) - CSRFCheck()(c) - asserts.False(c.IsAborted()) - } - - // 未通过检查 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFCheck()(c) - asserts.True(c.IsAborted()) - } -} diff --git a/middleware/share.go b/middleware/share.go deleted file mode 100644 index 488b703d..00000000 --- a/middleware/share.go +++ /dev/null @@ -1,133 +0,0 @@ -package middleware - -import ( - "fmt" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" -) - -// ShareOwner 检查当前登录用户是否为分享所有者 -func ShareOwner() gin.HandlerFunc { - return func(c *gin.Context) { - var user *model.User - if userCtx, ok := c.Get("user"); ok { - user = userCtx.(*model.User) - } else { - c.JSON(200, 
serializer.Err(serializer.CodeCheckLogin, "", nil)) - c.Abort() - return - } - - if share, ok := c.Get("share"); ok { - if share.(*model.Share).Creator().ID != user.ID { - c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil)) - c.Abort() - return - } - } - - c.Next() - } -} - -// ShareAvailable 检查分享是否可用 -func ShareAvailable() gin.HandlerFunc { - return func(c *gin.Context) { - var user *model.User - if userCtx, ok := c.Get("user"); ok { - user = userCtx.(*model.User) - } else { - user = model.NewAnonymousUser() - } - - share := model.GetShareByHashID(c.Param("id")) - - if share == nil || !share.IsAvailable() { - c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil)) - c.Abort() - return - } - - c.Set("user", user) - c.Set("share", share) - c.Next() - } -} - -// ShareCanPreview 检查分享是否可被预览 -func ShareCanPreview() gin.HandlerFunc { - return func(c *gin.Context) { - if share, ok := c.Get("share"); ok { - if share.(*model.Share).PreviewEnabled { - c.Next() - return - } - c.JSON(200, serializer.Err(serializer.CodeDisabledSharePreview, "", - nil)) - c.Abort() - return - } - c.Abort() - } -} - -// CheckShareUnlocked 检查分享是否已解锁 -func CheckShareUnlocked() gin.HandlerFunc { - return func(c *gin.Context) { - if shareCtx, ok := c.Get("share"); ok { - share := shareCtx.(*model.Share) - // 分享是否已解锁 - if share.Password != "" { - sessionKey := fmt.Sprintf("share_unlock_%d", share.ID) - unlocked := util.GetSession(c, sessionKey) != nil - if !unlocked { - c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, - "", nil)) - c.Abort() - return - } - } - - c.Next() - return - } - c.Abort() - } -} - -// BeforeShareDownload 分享被下载前的检查 -func BeforeShareDownload() gin.HandlerFunc { - return func(c *gin.Context) { - if shareCtx, ok := c.Get("share"); ok { - if userCtx, ok := c.Get("user"); ok { - share := shareCtx.(*model.Share) - user := userCtx.(*model.User) - - // 检查用户是否可以下载此分享的文件 - err := share.CanBeDownloadBy(user) - if err != nil { - c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(), - nil)) - c.Abort() - return - } - - // 对积分、下载次数进行更新 - err = share.DownloadBy(user, c) - if err != nil { - c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(), - nil)) - c.Abort() - return - } - - c.Next() - return - } - } - c.Abort() - } -} diff --git a/middleware/share_test.go b/middleware/share_test.go deleted file mode 100644 index 129076b5..00000000 --- a/middleware/share_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package middleware - -import ( - "net/http/httptest" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestShareAvailable(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareAvailable() - - // 分享不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "empty"}, - } - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 通过 - { - conf.SystemConfig.HashIDSalt = "" - // 用户组 - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - mock.ExpectQuery("SELECT(.+)shares(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "remain_downloads", "source_id"}). 
- AddRow(1, 1, 2), - ) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "x9T4"}, - } - testFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - asserts.NotNil(c.Get("user")) - asserts.NotNil(c.Get("share")) - } -} - -func TestShareCanPreview(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareCanPreview() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 可以预览 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{PreviewEnabled: true}) - testFunc(c) - asserts.False(c.IsAborted()) - } - - // 未开启预览 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{PreviewEnabled: false}) - testFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCheckShareUnlocked(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := CheckShareUnlocked() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 无密码 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.False(c.IsAborted()) - } - -} - -func TestBeforeShareDownload(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := BeforeShareDownload() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 用户不能下载 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{ - Group: model.Group{OptionsSerialized: model.GroupOption{}}, - }) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 可以下载 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{ - Model: gorm.Model{ID: 1}, - Group: model.Group{OptionsSerialized: model.GroupOption{ - ShareDownload: true, - }}, - }) - testFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestShareOwner(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareOwner() - - // 未登录 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 非用户所创建分享 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}}) - c.Set("user", &model.User{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 正常 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{}) - testFunc(c) - asserts.False(c.IsAborted()) - } -} diff --git a/middleware/wopi.go b/middleware/wopi.go index 41b8c01d..0d46866a 100644 --- a/middleware/wopi.go +++ b/middleware/wopi.go @@ -1,22 +1,21 @@ package middleware import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + 
"github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/wopi" "github.com/gin-gonic/gin" "net/http" "strings" ) -const ( - WopiSessionCtx = "wopi_session" -) - // WopiWriteAccess validates if write access is obtained. func WopiWriteAccess() gin.HandlerFunc { return func(c *gin.Context) { - session := c.MustGet(WopiSessionCtx).(*wopi.SessionCache) + session := c.MustGet(wopi.WopiSessionCtx).(*wopi.SessionCache) if session.Action != wopi.ActionEdit { c.Status(http.StatusNotFound) c.Header(wopi.ServerErrorHeader, "read-only access") @@ -28,8 +27,12 @@ func WopiWriteAccess() gin.HandlerFunc { } } -func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc { +func ViewerSessionValidation() gin.HandlerFunc { return func(c *gin.Context) { + dep := dependency.FromContext(c) + store := dep.KV() + settings := dep.SettingProvider() + accessToken := strings.Split(c.Query(wopi.AccessTokenQuery), ".") if len(accessToken) != 2 { c.Status(http.StatusForbidden) @@ -38,7 +41,7 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc { return } - sessionRaw, exist := store.Get(wopi.SessionCachePrefix + accessToken[0]) + sessionRaw, exist := store.Get(manager.ViewerSessionCachePrefix + accessToken[0]) if !exist { c.Status(http.StatusForbidden) c.Header(wopi.ServerErrorHeader, "invalid access token") @@ -46,25 +49,47 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc { return } - session := sessionRaw.(wopi.SessionCache) - user, err := model.GetActiveUserByID(session.UserID) - if err != nil { + session := sessionRaw.(manager.ViewerSessionCache) + if err := SetUserCtx(c, session.UserID); err != nil { c.Status(http.StatusInternalServerError) c.Header(wopi.ServerErrorHeader, "user not found") c.Abort() return } - fileID := c.MustGet("object_id").(uint) - if fileID != session.FileID { + fileId := hashid.FromContext(c) + if fileId != session.FileID { + c.Status(http.StatusForbidden) + c.Header(wopi.ServerErrorHeader, "invalid file") + c.Abort() + return + } + + // Check if the viewer is still available + viewers := settings.FileViewers(c) + var v *setting.Viewer + for _, group := range viewers { + for _, viewer := range group.Viewers { + if viewer.ID == session.ViewerID && !viewer.Disabled { + v = &viewer + break + } + } + + if v != nil { + break + } + } + + if v == nil { c.Status(http.StatusInternalServerError) - c.Header(wopi.ServerErrorHeader, "file not found") + c.Header(wopi.ServerErrorHeader, "viewer not found") c.Abort() return } - c.Set("user", &user) - c.Set(WopiSessionCtx, &session) + util.WithValue(c, manager.ViewerCtx{}, v) + util.WithValue(c, manager.ViewerSessionCacheCtx{}, &session) c.Next() } } diff --git a/middleware/wopi_test.go b/middleware/wopi_test.go deleted file mode 100644 index c6ca3270..00000000 --- a/middleware/wopi_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package middleware - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestWopiWriteAccess(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := WopiWriteAccess() - - // deny preview only session - { - c, _ := gin.CreateTestContext(rec) - 
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // pass - { - c, _ := gin.CreateTestContext(rec) - c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit}) - testFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestWopiAccessValidation(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - mockWopi := &wopimock.WopiClientMock{} - mockCache := cache.NewMemoStore() - testFunc := WopiAccessValidation(mockWopi, mockCache) - - // malformed access token - { - c, _ := gin.CreateTestContext(rec) - c.AddParam(wopi.AccessTokenQuery, "000") - testFunc(c) - asserts.True(c.IsAborted()) - } - - // session key not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - testFunc(c) - asserts.True(c.IsAborted()) - } - - // user key not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - testFunc(c) - asserts.True(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // file not found - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - c.Set("object_id", uint(0)) - testFunc(c) - asserts.True(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // all pass - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - c.Set("object_id", uint(1)) - testFunc(c) - asserts.False(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotPanics(func() { - c.MustGet(WopiSessionCtx) - }) - asserts.NotPanics(func() { - c.MustGet("user") - }) - } -} diff --git a/models/defaults.go b/models/defaults.go deleted file mode 100644 index afba508f..00000000 --- a/models/defaults.go +++ /dev/null @@ -1,146 +0,0 @@ -package model - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" -) - -var defaultSettings = []Setting{ - {Name: "siteURL", Value: `http://localhost`, Type: "basic"}, - {Name: "siteName", Value: `Cloudreve`, Type: "basic"}, - {Name: "register_enabled", Value: `1`, Type: "register"}, - {Name: "default_group", Value: `2`, Type: "register"}, - {Name: "siteKeywords", Value: `Cloudreve, cloud storage`, Type: "basic"}, 
- {Name: "siteDes", Value: `Cloudreve`, Type: "basic"}, - {Name: "siteTitle", Value: `Inclusive cloud storage for everyone`, Type: "basic"}, - {Name: "siteScript", Value: ``, Type: "basic"}, - {Name: "siteID", Value: uuid.Must(uuid.NewV4()).String(), Type: "basic"}, - {Name: "fromName", Value: `Cloudreve`, Type: "mail"}, - {Name: "mail_keepalive", Value: `30`, Type: "mail"}, - {Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"}, - {Name: "smtpHost", Value: `smtp.mxhichina.com`, Type: "mail"}, - {Name: "smtpPort", Value: `25`, Type: "mail"}, - {Name: "replyTo", Value: `abslant@126.com`, Type: "mail"}, - {Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"}, - {Name: "smtpPass", Value: ``, Type: "mail"}, - {Name: "smtpEncryption", Value: `0`, Type: "mail"}, - {Name: "maxEditSize", Value: `52428800`, Type: "file_edit"}, - {Name: "archive_timeout", Value: `600`, Type: "timeout"}, - {Name: "download_timeout", Value: `600`, Type: "timeout"}, - {Name: "preview_timeout", Value: `600`, Type: "timeout"}, - {Name: "doc_preview_timeout", Value: `600`, Type: "timeout"}, - {Name: "upload_session_timeout", Value: `86400`, Type: "timeout"}, - {Name: "slave_api_timeout", Value: `60`, Type: "timeout"}, - {Name: "slave_node_retry", Value: `3`, Type: "slave"}, - {Name: "slave_ping_interval", Value: `60`, Type: "slave"}, - {Name: "slave_recover_interval", Value: `120`, Type: "slave"}, - {Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"}, - {Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"}, - {Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"}, - {Name: "onedrive_callback_check", Value: `20`, Type: "timeout"}, - {Name: "folder_props_timeout", Value: `300`, Type: "timeout"}, - {Name: "chunk_retries", Value: `5`, Type: "retry"}, - {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"}, - {Name: "reset_after_upload_failed", Value: `0`, Type: "upload"}, - {Name: "use_temp_chunk_buffer", Value: `1`, Type: "upload"}, - {Name: "login_captcha", Value: `0`, Type: "login"}, - {Name: "reg_captcha", Value: `0`, Type: "login"}, - {Name: "email_active", Value: `0`, Type: "register"}, - {Name: "mail_activation_template", Value: `激活您的账户
Activate your {siteTitle} account
Dear {userName}
Thank you for registering at {siteTitle}. Please click the button below to activate your account.
Activate account
Thank you for choosing {siteTitle}.
`, Type: "mail_template"}, - {Name: "forget_captcha", Value: `0`, Type: "login"}, - {Name: "mail_reset_pwd_template", Value: `Reset your password
Reset your {siteTitle} password
Dear {userName}
Please click the button below to reset your password. If you did not request this, please ignore this email.
Reset password
Thank you for choosing {siteTitle}.
`, Type: "mail_template"}, - {Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"}, - {Name: "hot_share_num", Value: `10`, Type: "share"}, - {Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"}, - {Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"}, - {Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"}, - {Name: "max_worker_num", Value: `10`, Type: "task"}, - {Name: "max_parallel_transfer", Value: `4`, Type: "task"}, - {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, - {Name: "temp_path", Value: "temp", Type: "path"}, - {Name: "avatar_path", Value: "avatar", Type: "path"}, - {Name: "avatar_size", Value: "2097152", Type: "avatar"}, - {Name: "avatar_size_l", Value: "200", Type: "avatar"}, - {Name: "avatar_size_m", Value: "130", Type: "avatar"}, - {Name: "avatar_size_s", Value: "50", Type: "avatar"}, - {Name: "home_view_method", Value: "icon", Type: "view"}, - {Name: "share_view_method", Value: "list", Type: "view"}, - {Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"}, - {Name: "cron_recycle_upload_session", Value: "@every 1h30m", Type: "cron"}, - {Name: "authn_enabled", Value: "0", Type: "authn"}, - {Name: "captcha_type", Value: "normal", Type: "captcha"}, - {Name: "captcha_height", Value: "60", Type: "captcha"}, - {Name: "captcha_width", Value: "240", Type: "captcha"}, - {Name: "captcha_mode", Value: "3", Type: "captcha"}, - {Name: "captcha_ComplexOfNoiseText", Value: "0", Type: "captcha"}, - {Name: "captcha_ComplexOfNoiseDot", Value: "0", Type: "captcha"}, - {Name: "captcha_IsShowHollowLine", Value: "0", Type: "captcha"}, - {Name: "captcha_IsShowNoiseDot", Value: "1", Type: "captcha"}, - {Name: "captcha_IsShowNoiseText", Value: "0", Type: "captcha"}, - {Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"}, - {Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"}, - {Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"}, - {Name: "captcha_ReCaptchaKey", Value: "defaultKey", Type: "captcha"}, - {Name: "captcha_ReCaptchaSecret", Value: "defaultSecret", Type: "captcha"}, - {Name: "captcha_TCaptcha_CaptchaAppId", Value: "", Type: "captcha"}, - {Name: "captcha_TCaptcha_AppSecretKey", Value: "", Type: "captcha"}, - {Name: "captcha_TCaptcha_SecretId", Value: "", Type: "captcha"}, - {Name: "captcha_TCaptcha_SecretKey", Value: "", Type: "captcha"}, - {Name: "thumb_width", Value: "400", Type: "thumb"}, - {Name: "thumb_height", Value: "300", Type: "thumb"}, - {Name: "thumb_file_suffix", Value: "._thumb", Type: "thumb"}, - {Name: "thumb_max_task_count", Value: "-1", 
Type: "thumb"}, - {Name: "thumb_encode_method", Value: "jpg", Type: "thumb"}, - {Name: "thumb_gc_after_gen", Value: "0", Type: "thumb"}, - {Name: "thumb_encode_quality", Value: "85", Type: "thumb"}, - {Name: "thumb_builtin_enabled", Value: "1", Type: "thumb"}, - {Name: "thumb_vips_enabled", Value: "0", Type: "thumb"}, - {Name: "thumb_ffmpeg_enabled", Value: "0", Type: "thumb"}, - {Name: "thumb_vips_path", Value: "vips", Type: "thumb"}, - {Name: "thumb_vips_exts", Value: "csv,mat,img,hdr,pbm,pgm,ppm,pfm,pnm,svg,svgz,j2k,jp2,jpt,j2c,jpc,gif,png,jpg,jpeg,jpe,webp,tif,tiff,fits,fit,fts,exr,jxl,pdf,heic,heif,avif,svs,vms,vmu,ndpi,scn,mrxs,svslide,bif,raw", Type: "thumb"}, - {Name: "thumb_ffmpeg_seek", Value: "00:00:01.00", Type: "thumb"}, - {Name: "thumb_ffmpeg_path", Value: "ffmpeg", Type: "thumb"}, - {Name: "thumb_ffmpeg_exts", Value: "3g2,3gp,asf,asx,avi,divx,flv,m2ts,m2v,m4v,mkv,mov,mp4,mpeg,mpg,mts,mxf,ogv,rm,swf,webm,wmv", Type: "thumb"}, - {Name: "thumb_libreoffice_path", Value: "soffice", Type: "thumb"}, - {Name: "thumb_libreoffice_enabled", Value: "0", Type: "thumb"}, - {Name: "thumb_libreoffice_exts", Value: "md,ods,ots,fods,uos,xlsx,xml,xls,xlt,dif,dbf,html,slk,csv,xlsm,docx,dotx,doc,dot,rtf,xlsm,xlst,xls,xlw,xlc,xlt,pptx,ppsx,potx,pomx,ppt,pps,ppm,pot,pom", Type: "thumb"}, - {Name: "thumb_proxy_enabled", Value: "0", Type: "thumb"}, - {Name: "thumb_proxy_policy", Value: "[]", Type: "thumb"}, - {Name: "thumb_max_src_size", Value: "31457280", Type: "thumb"}, - {Name: "thumb_libraw_path", Value: "simple_dcraw", Type: "thumb"}, - {Name: "thumb_libraw_enabled", Value: "0", Type: "thumb"}, - {Name: "thumb_libraw_exts", Value: "arw,raf,dng", Type: "thumb"}, - {Name: "pwa_small_icon", Value: "/static/img/favicon.ico", Type: "pwa"}, - {Name: "pwa_medium_icon", Value: "/static/img/logo192.png", Type: "pwa"}, - {Name: "pwa_large_icon", Value: "/static/img/logo512.png", Type: "pwa"}, - {Name: "pwa_display", Value: "standalone", Type: "pwa"}, - {Name: "pwa_theme_color", Value: "#000000", Type: "pwa"}, - {Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"}, - {Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"}, - {Name: "show_app_promotion", Value: "1", Type: "mobile"}, - {Name: "public_resource_maxage", Value: "86400", Type: "timeout"}, - {Name: "wopi_enabled", Value: "0", Type: "wopi"}, - {Name: "wopi_endpoint", Value: "", Type: "wopi"}, - {Name: "wopi_max_size", Value: "52428800", Type: "wopi"}, - {Name: "wopi_session_timeout", Value: "36000", Type: "wopi"}, -} - -func InitSlaveDefaults() { - for _, setting := range defaultSettings { - cache.Set("setting_"+setting.Name, setting.Value, -1) - } -} diff --git a/models/download.go b/models/download.go deleted file mode 100644 index dce50f3a..00000000 --- a/models/download.go +++ /dev/null @@ -1,128 +0,0 @@ -package model - -import ( - "encoding/json" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// Download 离线下载队列模型 -type Download struct { - gorm.Model - Status int // 任务状态 - Type int // 任务类型 - Source string `gorm:"type:text"` // 文件下载地址 - TotalSize uint64 // 文件大小 - DownloadedSize uint64 // 文件大小 - GID string `gorm:"size:32,index:gid"` // 任务ID - Speed int // 下载速度 - Parent string `gorm:"type:text"` // 存储目录 - Attrs string `gorm:"size:4294967295"` // 任务状态属性 - Error string `gorm:"type:text"` // 错误描述 - Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 - UserID uint // 发起者UID - TaskID uint // 
对应的转存任务ID - NodeID uint // 处理任务的节点ID - - // 关联模型 - User *User `gorm:"PRELOAD:false,association_autoupdate:false"` - - // 数据库忽略字段 - StatusInfo rpc.StatusInfo `gorm:"-"` - Task *Task `gorm:"-"` - NodeName string `gorm:"-"` -} - -// AfterFind 找到下载任务后的钩子,处理Status结构 -func (task *Download) AfterFind() (err error) { - // 解析状态 - if task.Attrs != "" { - err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) - } - - if task.TaskID != 0 { - task.Task, _ = GetTasksByID(task.TaskID) - } - - return err -} - -// BeforeSave Save下载任务前的钩子 -func (task *Download) BeforeSave() (err error) { - // 解析状态 - if task.Attrs != "" { - err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) - } - return err -} - -// Create 创建离线下载记录 -func (task *Download) Create() (uint, error) { - if err := DB.Create(task).Error; err != nil { - util.Log().Warning("Failed to insert download record: %s", err) - return 0, err - } - return task.ID, nil -} - -// Save 更新 -func (task *Download) Save() error { - if err := DB.Save(task).Error; err != nil { - util.Log().Warning("Failed to update download record: %s", err) - return err - } - return nil -} - -// GetDownloadsByStatus 根据状态检索下载 -func GetDownloadsByStatus(status ...int) []Download { - var tasks []Download - DB.Where("status in (?)", status).Find(&tasks) - return tasks -} - -// GetDownloadsByStatusAndUser 根据状态检索和用户ID下载 -// page 为 0 表示列出所有,非零时分页 -func GetDownloadsByStatusAndUser(page, uid uint, status ...int) []Download { - var tasks []Download - dbChain := DB - if page > 0 { - dbChain = dbChain.Limit(10).Offset((page - 1) * 10).Order("updated_at DESC") - } - dbChain.Where("user_id = ? and status in (?)", uid, status).Find(&tasks) - return tasks -} - -// GetDownloadByGid 根据GID和用户ID查找下载 -func GetDownloadByGid(gid string, uid uint) (*Download, error) { - download := &Download{} - result := DB.Where("user_id = ? 
and g_id = ?", uid, gid).First(download) - return download, result.Error -} - -// GetOwner 获取下载任务所属用户 -func (task *Download) GetOwner() *User { - if task.User == nil { - if user, err := GetUserByID(task.UserID); err == nil { - return &user - } - } - return task.User -} - -// Delete 删除离线下载记录 -func (download *Download) Delete() error { - return DB.Model(download).Delete(download).Error -} - -// GetNodeID 返回任务所属节点ID -func (task *Download) GetNodeID() uint { - // 兼容3.4版本之前生成的下载记录 - if task.NodeID == 0 { - return 1 - } - - return task.NodeID -} diff --git a/models/download_test.go b/models/download_test.go deleted file mode 100644 index 367afb78..00000000 --- a/models/download_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestDownload_Create(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - download := Download{GID: "1"} - id, err := download.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - download := Download{GID: "1"} - id, err := download.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestDownload_AfterFind(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - download := Download{Attrs: `{"gid":"123"}`} - err := download.AfterFind() - asserts.NoError(err) - asserts.Equal("123", download.StatusInfo.Gid) - } - - // 忽略空值 - { - download := Download{Attrs: ``} - err := download.AfterFind() - asserts.NoError(err) - asserts.Equal("", download.StatusInfo.Gid) - } - - // 解析失败 - { - download := Download{Attrs: `?`} - err := download.BeforeSave() - asserts.Error(err) - asserts.Equal("", download.StatusInfo.Gid) - } - -} - -func TestDownload_Save(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - download := Download{ - Model: gorm.Model{ - ID: 1, - }, - Attrs: `{"gid":"123"}`, - } - err := download.Save() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("123", download.StatusInfo.Gid) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - download := Download{ - Model: gorm.Model{ - ID: 1, - }, - } - err := download.Save() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestGetDownloadsByStatus(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(0, 1).WillReturnRows(sqlmock.NewRows([]string{"gid"}).AddRow("0").AddRow("1")) - res := GetDownloadsByStatus(0, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) -} - -func TestGetDownloadByGid(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(2, "gid").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) - res, err := GetDownloadByGid("gid", 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(res.GID, "1") -} - -func TestDownload_GetOwner(t *testing.T) { - asserts := assert.New(t) - - // 已经有User对象 - { - download 
:= &Download{User: &User{Nick: "nick"}} - user := download.GetOwner() - asserts.NotNil(user) - asserts.Equal("nick", user.Nick) - } - - // 无User对象 - { - download := &Download{UserID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"nick"}).AddRow("nick")) - user := download.GetOwner() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(user) - asserts.Equal("nick", user.Nick) - } -} - -func TestGetDownloadsByStatusAndUser(t *testing.T) { - asserts := assert.New(t) - - // 列出全部 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) - res := GetDownloadsByStatusAndUser(0, 1, 1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - } - - // 列出全部,分页 - { - mock.ExpectQuery("SELECT(.+)DESC(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) - res := GetDownloadsByStatusAndUser(2, 1, 1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - } -} - -func TestDownload_Delete(t *testing.T) { - asserts := assert.New(t) - share := Download{} - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Delete() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - -} - -func TestDownload_GetNodeID(t *testing.T) { - a := assert.New(t) - record := Download{} - - // compatible with 3.4 - a.EqualValues(1, record.GetNodeID()) - - record.NodeID = 5 - a.EqualValues(5, record.GetNodeID()) -} diff --git a/models/file.go b/models/file.go deleted file mode 100644 index bfe49cba..00000000 --- a/models/file.go +++ /dev/null @@ -1,472 +0,0 @@ -package model - -import ( - "encoding/gob" - "encoding/json" - "errors" - "fmt" - "path" - "path/filepath" - "strings" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// File 文件 -type File struct { - // 表字段 - gorm.Model - Name string `gorm:"unique_index:idx_only_one"` - SourceName string `gorm:"type:text"` - UserID uint `gorm:"index:user_id;unique_index:idx_only_one"` - Size uint64 - PicInfo string - FolderID uint `gorm:"index:folder_id;unique_index:idx_only_one"` - PolicyID uint - UploadSessionID *string `gorm:"index:session_id;unique_index:session_only_one"` - Metadata string `gorm:"type:text"` - - // 关联模型 - Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` - - // 数据库忽略字段 - Position string `gorm:"-"` - MetadataSerialized map[string]string `gorm:"-"` -} - -// Thumb related metadata -const ( - ThumbStatusNotExist = "" - ThumbStatusExist = "exist" - ThumbStatusNotAvailable = "not_available" - - ThumbStatusMetadataKey = "thumb_status" - ThumbSidecarMetadataKey = "thumb_sidecar" - - ChecksumMetadataKey = "webdav_checksum" -) - -func init() { - // 注册缓存用到的复杂结构 - gob.Register(File{}) -} - -// Create 创建文件记录 -func (file *File) Create() error { - tx := DB.Begin() - - if err := tx.Create(file).Error; err != nil { - util.Log().Warning("Failed to insert file record: %s", err) - tx.Rollback() - return err - } - - user := &User{} - user.ID = file.UserID - if err := user.ChangeStorage(tx, "+", file.Size); err != nil { - tx.Rollback() - return err - } - - return tx.Commit().Error -} - -// AfterFind 找到文件后的钩子 -func (file *File) AfterFind() (err error) { - // 反序列化文件元数据 - if file.Metadata != "" { - err = json.Unmarshal([]byte(file.Metadata), &file.MetadataSerialized) - } else { - file.MetadataSerialized = make(map[string]string) - } - - return -} - -// 
BeforeSave Save策略前的钩子 -func (file *File) BeforeSave() (err error) { - if len(file.MetadataSerialized) > 0 { - metaValue, err := json.Marshal(&file.MetadataSerialized) - file.Metadata = string(metaValue) - return err - } - - return nil -} - -// GetChildFile 查找目录下名为name的子文件 -func (folder *Folder) GetChildFile(name string) (*File, error) { - var file File - result := DB.Where("folder_id = ? AND name = ?", folder.ID, name).Find(&file) - - if result.Error == nil { - file.Position = path.Join(folder.Position, folder.Name) - } - return &file, result.Error -} - -// GetChildFiles 查找目录下子文件 -func (folder *Folder) GetChildFiles() ([]File, error) { - var files []File - result := DB.Where("folder_id = ?", folder.ID).Find(&files) - - if result.Error == nil { - for i := 0; i < len(files); i++ { - files[i].Position = path.Join(folder.Position, folder.Name) - } - } - return files, result.Error -} - -// GetFilesByIDs 根据文件ID批量获取文件, -// UID为0表示忽略用户,只根据文件ID检索 -func GetFilesByIDs(ids []uint, uid uint) ([]File, error) { - return GetFilesByIDsFromTX(DB, ids, uid) -} - -func GetFilesByIDsFromTX(tx *gorm.DB, ids []uint, uid uint) ([]File, error) { - var files []File - var result *gorm.DB - if uid == 0 { - result = tx.Where("id in (?)", ids).Find(&files) - } else { - result = tx.Where("id in (?) AND user_id = ?", ids, uid).Find(&files) - } - return files, result.Error -} - -// GetFilesByKeywords 根据关键字搜索文件, -// UID为0表示忽略用户,只根据文件ID检索. 如果 parents 非空, 则只限制在 parent 包含的目录下搜索 -func GetFilesByKeywords(uid uint, parents []uint, keywords ...interface{}) ([]File, error) { - var ( - files []File - result = DB - conditions string - ) - - // 生成查询条件 - for i := 0; i < len(keywords); i++ { - conditions += "name like ?" - if i != len(keywords)-1 { - conditions += " or " - } - } - - if uid != 0 { - result = result.Where("user_id = ?", uid) - } - - if len(parents) > 0 { - result = result.Where("folder_id in (?)", parents) - } - - result = result.Where("("+conditions+")", keywords...).Find(&files) - - return files, result.Error -} - -// GetChildFilesOfFolders 批量检索目录子文件 -func GetChildFilesOfFolders(folders *[]Folder) ([]File, error) { - // 将所有待检索目录ID抽离,以便检索文件 - folderIDs := make([]uint, 0, len(*folders)) - for _, value := range *folders { - folderIDs = append(folderIDs, value.ID) - } - - // 检索文件 - var files []File - result := DB.Where("folder_id in (?)", folderIDs).Find(&files) - return files, result.Error -} - -// GetUploadPlaceholderFiles 获取所有上传占位文件 -// UID为0表示忽略用户 -func GetUploadPlaceholderFiles(uid uint) []*File { - query := DB - if uid != 0 { - query = query.Where("user_id = ?", uid) - } - - var files []*File - query.Where("upload_session_id is not NULL").Find(&files) - return files -} - -// GetPolicy 获取文件所属策略 -func (file *File) GetPolicy() *Policy { - if file.Policy.Model.ID == 0 { - file.Policy, _ = GetPolicyByID(file.PolicyID) - } - return &file.Policy -} - -// RemoveFilesWithSoftLinks 去除给定的文件列表中有软链接的文件 -func RemoveFilesWithSoftLinks(files []File) ([]File, error) { - // 结果值 - filteredFiles := make([]File, 0) - - if len(files) == 0 { - return filteredFiles, nil - } - - // 查询软链接的文件 - filesWithSoftLinks := make([]File, 0) - for _, file := range files { - var softLinkFile File - res := DB. - Where("source_name = ? and policy_id = ? and id != ?", file.SourceName, file.PolicyID, file.ID). 
- First(&softLinkFile) - if res.Error == nil { - filesWithSoftLinks = append(filesWithSoftLinks, softLinkFile) - } - } - - // 过滤具有软连接的文件 - // TODO: 优化复杂度 - if len(filesWithSoftLinks) == 0 { - filteredFiles = files - } else { - for i := 0; i < len(files); i++ { - finder := false - for _, value := range filesWithSoftLinks { - if value.PolicyID == files[i].PolicyID && value.SourceName == files[i].SourceName { - finder = true - break - } - } - if !finder { - filteredFiles = append(filteredFiles, files[i]) - } - - } - } - - return filteredFiles, nil - -} - -// DeleteFiles 批量删除文件记录并归还容量 -func DeleteFiles(files []*File, uid uint) error { - tx := DB.Begin() - user := &User{} - user.ID = uid - var size uint64 - for _, file := range files { - if uid > 0 && file.UserID != uid { - tx.Rollback() - return errors.New("user id not consistent") - } - - result := tx.Unscoped().Where("size = ?", file.Size).Delete(file) - if result.Error != nil { - tx.Rollback() - return result.Error - } - - if result.RowsAffected == 0 { - tx.Rollback() - return errors.New("file size is dirty") - } - - size += file.Size - } - - if uid > 0 { - if err := user.ChangeStorage(tx, "-", size); err != nil { - tx.Rollback() - return err - } - } - - return tx.Commit().Error -} - -// GetFilesByParentIDs 根据父目录ID查找文件 -func GetFilesByParentIDs(ids []uint, uid uint) ([]File, error) { - files := make([]File, 0, len(ids)) - result := DB.Where("user_id = ? and folder_id in (?)", uid, ids).Find(&files) - return files, result.Error -} - -// GetFilesByUploadSession 查找上传会话对应的文件 -func GetFilesByUploadSession(sessionID string, uid uint) (*File, error) { - file := File{} - result := DB.Where("user_id = ? and upload_session_id = ?", uid, sessionID).Find(&file) - return &file, result.Error -} - -// Rename 重命名文件 -func (file *File) Rename(new string) error { - if file.MetadataSerialized[ThumbStatusMetadataKey] == ThumbStatusNotAvailable { - if !strings.EqualFold(filepath.Ext(new), filepath.Ext(file.Name)) { - // Reset thumb status for new ext name. - if err := file.resetThumb(); err != nil { - return err - } - } - } - - return DB.Model(&file).Set("gorm:association_autoupdate", false).Updates(map[string]interface{}{ - "name": new, - "metadata": file.Metadata, - }).Error -} - -// UpdatePicInfo 更新文件的图像信息 -func (file *File) UpdatePicInfo(value string) error { - return DB.Model(&file).Set("gorm:association_autoupdate", false).UpdateColumns(File{PicInfo: value}).Error -} - -// UpdateMetadata 新增或修改文件的元信息 -func (file *File) UpdateMetadata(data map[string]string) error { - if file.MetadataSerialized == nil { - file.MetadataSerialized = make(map[string]string) - } - - for k, v := range data { - file.MetadataSerialized[k] = v - } - metaValue, err := json.Marshal(&file.MetadataSerialized) - if err != nil { - return err - } - - return DB.Model(&file).Set("gorm:association_autoupdate", false).UpdateColumns(File{Metadata: string(metaValue)}).Error -} - -// UpdateSize 更新文件的大小信息 -// TODO: 全局锁 -func (file *File) UpdateSize(value uint64) error { - tx := DB.Begin() - var sizeDelta uint64 - operator := "+" - user := User{} - user.ID = file.UserID - if value > file.Size { - sizeDelta = value - file.Size - } else { - operator = "-" - sizeDelta = file.Size - value - } - - if err := file.resetThumb(); err != nil { - tx.Rollback() - return err - } - - if res := tx.Model(&file). - Where("size = ?", file.Size). - Set("gorm:association_autoupdate", false). 
- Updates(map[string]interface{}{ - "size": value, - "metadata": file.Metadata, - }); res.Error != nil { - tx.Rollback() - return res.Error - } - - if err := user.ChangeStorage(tx, operator, sizeDelta); err != nil { - tx.Rollback() - return err - } - - file.Size = value - return tx.Commit().Error -} - -// UpdateSourceName 更新文件的源文件名 -func (file *File) UpdateSourceName(value string) error { - if err := file.resetThumb(); err != nil { - return err - } - - return DB.Model(&file).Set("gorm:association_autoupdate", false).Updates(map[string]interface{}{ - "source_name": value, - "metadata": file.Metadata, - }).Error -} - -func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error { - file.UploadSessionID = nil - if lastModified != nil { - file.UpdatedAt = *lastModified - } - - return DB.Model(file).UpdateColumns(map[string]interface{}{ - "upload_session_id": file.UploadSessionID, - "updated_at": file.UpdatedAt, - "pic_info": picInfo, - }).Error -} - -// CanCopy 返回文件是否可被复制 -func (file *File) CanCopy() bool { - return file.UploadSessionID == nil -} - -// CreateOrGetSourceLink creates a SourceLink model. If the given model exists, the existing -// model will be returned. -func (file *File) CreateOrGetSourceLink() (*SourceLink, error) { - res := &SourceLink{} - err := DB.Set("gorm:auto_preload", true).Where("file_id = ?", file.ID).Find(&res).Error - if err == nil && res.ID > 0 { - return res, nil - } - - res.FileID = file.ID - res.Name = file.Name - if err := DB.Save(res).Error; err != nil { - return nil, fmt.Errorf("failed to insert SourceLink: %w", err) - } - - res.File = *file - return res, nil -} - -func (file *File) resetThumb() error { - if _, ok := file.MetadataSerialized[ThumbStatusMetadataKey]; !ok { - return nil - } - - delete(file.MetadataSerialized, ThumbStatusMetadataKey) - metaValue, err := json.Marshal(&file.MetadataSerialized) - file.Metadata = string(metaValue) - return err -} - -/* - 实现 webdav.FileInfo 接口 -*/ - -func (file *File) GetName() string { - return file.Name -} - -func (file *File) GetSize() uint64 { - return file.Size -} -func (file *File) ModTime() time.Time { - return file.UpdatedAt -} - -func (file *File) IsDir() bool { - return false -} - -func (file *File) GetPosition() string { - return file.Position -} - -// ShouldLoadThumb returns if file explorer should try to load thumbnail for this file. -// `True` does not guarantee the load request will success in next step, but the client -// should try to load and fallback to default placeholder in case error returned. 
-func (file *File) ShouldLoadThumb() bool { - return file.MetadataSerialized[ThumbStatusMetadataKey] != ThumbStatusNotAvailable -} - -// return sidecar thumb file name -func (file *File) ThumbFile() string { - return file.SourceName + GetSettingByNameWithDefault("thumb_file_suffix", "._thumb") -} diff --git a/models/file_test.go b/models/file_test.go deleted file mode 100644 index 83198fc9..00000000 --- a/models/file_test.go +++ /dev/null @@ -1,785 +0,0 @@ -package model - -import ( - "errors" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFile_Create(t *testing.T) { - asserts := assert.New(t) - file := File{ - Name: "123", - } - - // 无法插入文件记录 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := file.Create() - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 无法更新用户容量 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := file.Create() - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - err := file.Create() - asserts.NoError(err) - asserts.Equal(uint(5), file.ID) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_AfterFind(t *testing.T) { - a := assert.New(t) - - // metadata not empty - { - file := File{ - Name: "123", - Metadata: "{\"name\":\"123\"}", - } - - a.NoError(file.AfterFind()) - a.Equal("123", file.MetadataSerialized["name"]) - } - - // metadata empty - { - file := File{ - Name: "123", - Metadata: "", - } - a.Nil(file.MetadataSerialized) - a.NoError(file.AfterFind()) - a.NotNil(file.MetadataSerialized) - } -} - -func TestFile_BeforeSave(t *testing.T) { - a := assert.New(t) - - // metadata not empty - { - file := File{ - Name: "123", - MetadataSerialized: map[string]string{ - "name": "123", - }, - } - - a.NoError(file.BeforeSave()) - a.Equal("{\"name\":\"123\"}", file.Metadata) - } - - // metadata empty - { - file := File{ - Name: "123", - } - a.NoError(file.BeforeSave()) - a.Equal("", file.Metadata) - } -} - -func TestFolder_GetChildFile(t *testing.T) { - asserts := assert.New(t) - folder := Folder{Model: gorm.Model{ID: 1}, Name: "/"} - // 存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "1.txt"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - file, err := folder.GetChildFile("1.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("1.txt", file.Name) - asserts.Equal("/", file.Position) - } - - // 不存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "1.txt"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - _, err := folder.GetChildFile("1.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestFolder_GetChildFiles(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Model: gorm.Model{ - ID: 1, - }, - Position: "/123", - Name: "456", - } - - // 找不到 - mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) - files, err := folder.GetChildFiles() - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找到了 - mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) - files, err = folder.GetChildFiles() - asserts.NoError(err) - asserts.Len(files, 2) - asserts.Equal("/123/456", files[0].Position) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestGetFilesByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnError(errors.New("error")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Len(folders, 0) - } - - // 部分找到 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } - - // 忽略UID查找 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } -} - -func TestGetChildFilesOfFolders(t *testing.T) { - asserts := assert.New(t) - testFolder := []Folder{ - Folder{ - Model: gorm.Model{ID: 3}, - }, - Folder{ - Model: gorm.Model{ID: 4}, - }, Folder{ - Model: gorm.Model{ID: 5}, - }, - } - - // 出错 - { - mock.ExpectQuery("SELECT(.+)folder_id").WithArgs(3, 4, 5).WillReturnError(errors.New("not found")) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 找到2个 - { - mock.ExpectQuery("SELECT(.+)folder_id"). - WithArgs(3, 4, 5). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(3, "3"). - AddRow(4, "4"), - ) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.NoError(err) - asserts.Len(files, 2) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 全部找到 - { - mock.ExpectQuery("SELECT(.+)folder_id"). - WithArgs(3, 4, 5). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(3, "3"). - AddRow(4, "4"). - AddRow(5, "5"), - ) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.NoError(err) - asserts.Len(files, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestGetUploadPlaceholderFiles(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)upload_session_id(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - files := GetUploadPlaceholderFiles(1) - a.NoError(mock.ExpectationsWereMet()) - a.Len(files, 1) -} - -func TestFile_GetPolicy(t *testing.T) { - asserts := assert.New(t) - - // 空策略 - { - file := File{ - PolicyID: 23, - } - mock.ExpectQuery("SELECT(.+)policies(.+)"). 
- WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(23, "name"), - ) - file.GetPolicy() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint(23), file.Policy.ID) - } - - // 非空策略 - { - file := File{ - PolicyID: 23, - Policy: Policy{Model: gorm.Model{ID: 24}}, - } - file.GetPolicy() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint(24), file.Policy.ID) - } -} - -func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) { - asserts := assert.New(t) - // 传入空 - { - mock.ExpectQuery("SELECT(.+)files(.+)") - file, err := RemoveFilesWithSoftLinks([]File{}) - asserts.Error(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(len(file), 0) - DB.Find(&File{}) - } -} - -func TestRemoveFilesWithSoftLinks(t *testing.T) { - asserts := assert.New(t) - files := []File{ - File{ - Model: gorm.Model{ID: 1}, - SourceName: "1.txt", - PolicyID: 23, - }, - File{ - Model: gorm.Model{ID: 2}, - SourceName: "2.txt", - PolicyID: 24, - }, - } - - // 传入空文件列表 - { - file, err := RemoveFilesWithSoftLinks([]File{}) - asserts.NoError(err) - asserts.Empty(file) - } - - // 全都没有 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files, file) - } - - // 第二个是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 24, "2.txt"), - ) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files[:1], file) - } - - // 第一个是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 23, "1.txt"), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files[1:], file) - } - // 全部是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 23, "1.txt"), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 24, "2.txt"), - ) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(file, 0) - } -} - -func TestDeleteFiles(t *testing.T) { - a := assert.New(t) - - // uid 不一致 - { - err := DeleteFiles([]*File{{UserID: 2}}, 1) - a.Contains("user id not consistent", err.Error()) - } - - // 删除失败 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). 
- WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFiles([]*File{{UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - } - - // 无法变更用户容量 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFiles([]*File{{UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - } - - // 文件脏读 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 0)) - mock.ExpectRollback() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - a.Contains("file size is dirty", err.Error()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WithArgs(uint64(3), sqlmock.AnyArg(), uint(1)).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.NoError(err) - } - - // 成功, 关联用户不存在 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 0) - a.NoError(mock.ExpectationsWereMet()) - a.NoError(err) - } -} - -func TestGetFilesByParentIDs(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 4, 5, 6). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "4.txt"). - AddRow(5, "5.txt"). - AddRow(6, "6.txt"), - ) - files, err := GetFilesByParentIDs([]uint{4, 5, 6}, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(files, 3) -} - -func TestGetFilesByUploadSession(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "sessionID"). 
- WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}).AddRow(4, "4.txt")) - files, err := GetFilesByUploadSession("sessionID", 1) - a.NoError(err) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("4.txt", files.Name) -} - -func TestFile_Updates(t *testing.T) { - asserts := assert.New(t) - file := File{Model: gorm.Model{ID: 1}} - - // rename - { - // not reset thumb - { - file := File{Model: gorm.Model{ID: 1}} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // thumb not available, rename base name only - { - file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{ - ThumbStatusMetadataKey: ThumbStatusNotAvailable, - }, - Metadata: "{}"} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(ThumbStatusNotAvailable, file.MetadataSerialized[ThumbStatusMetadataKey]) - } - - // thumb not available, rename base name only - { - file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{ - ThumbStatusMetadataKey: ThumbStatusNotAvailable, - }} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.jpg", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName.jpg") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Empty(file.MetadataSerialized[ThumbStatusMetadataKey]) - } - } - - // UpdatePicInfo - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.UpdatePicInfo("1,1") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // UpdateSourceName - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.UpdateSourceName("newName") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } -} - -func TestFile_UpdateSize(t *testing.T) { - a := assert.New(t) - - // 增加成功 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 11, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)+(.+)").WithArgs(uint64(1), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(file.UpdateSize(11)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 减少成功 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 文件更新失败 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnError(errors.New("error")) - mock.ExpectRollback() - - 
a.Error(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 用户容量更新失败 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnError(errors.New("error")) - mock.ExpectRollback() - - a.Error(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_PopChunkToFile(t *testing.T) { - a := assert.New(t) - timeNow := time.Now() - file := File{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(file.PopChunkToFile(&timeNow, "1,1")) -} - -func TestFile_CanCopy(t *testing.T) { - a := assert.New(t) - file := File{} - a.True(file.CanCopy()) - file.UploadSessionID = &file.Name - a.False(file.CanCopy()) -} - -func TestFile_FileInfoInterface(t *testing.T) { - asserts := assert.New(t) - file := File{ - Model: gorm.Model{ - UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), - }, - Name: "test_name", - SourceName: "", - UserID: 0, - Size: 10, - PicInfo: "", - FolderID: 0, - PolicyID: 0, - Policy: Policy{}, - Position: "/test", - } - - name := file.GetName() - asserts.Equal("test_name", name) - - size := file.GetSize() - asserts.Equal(uint64(10), size) - - asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), file.ModTime()) - asserts.False(file.IsDir()) - asserts.Equal("/test", file.GetPosition()) -} - -func TestGetFilesByKeywords(t *testing.T) { - asserts := assert.New(t) - - // 未指定用户 - { - mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(0, nil, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 指定用户 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(1, nil, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 指定父目录 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 12, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(1, []uint{12}, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } -} - -func TestFile_CreateOrGetSourceLink(t *testing.T) { - a := assert.New(t) - file := &File{} - file.ID = 1 - - // 已存在,返回老的 SourceLink - { - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - res, err := file.CreateOrGetSourceLink() - a.NoError(err) - a.EqualValues(2, res.ID) - a.NoError(mock.ExpectationsWereMet()) - } - - // 不存在,插入失败 - { - expectedErr := errors.New("error") - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr) - mock.ExpectRollback() - res, err := file.CreateOrGetSourceLink() - a.Nil(res) - a.ErrorIs(err, expectedErr) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - 
res, err := file.CreateOrGetSourceLink() - a.NoError(err) - a.EqualValues(2, res.ID) - a.EqualValues(file.ID, res.File.ID) - a.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_UpdateMetadata(t *testing.T) { - a := assert.New(t) - file := &File{} - file.ID = 1 - - // 更新失败 - { - expectedErr := errors.New("error") - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnError(expectedErr) - mock.ExpectRollback() - a.ErrorIs(file.UpdateMetadata(map[string]string{"1": "1"}), expectedErr) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(file.UpdateMetadata(map[string]string{"1": "1"})) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("1", file.MetadataSerialized["1"]) - } -} - -func TestFile_ShouldLoadThumb(t *testing.T) { - a := assert.New(t) - file := &File{ - MetadataSerialized: map[string]string{}, - } - file.ID = 1 - - // 无缩略图 - { - file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusNotAvailable - a.False(file.ShouldLoadThumb()) - } - - // 有缩略图 - { - file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusExist - a.True(file.ShouldLoadThumb()) - } -} - -func TestFile_ThumbFile(t *testing.T) { - a := assert.New(t) - file := &File{ - SourceName: "test", - MetadataSerialized: map[string]string{}, - } - file.ID = 1 - - a.Equal("test._thumb", file.ThumbFile()) -} diff --git a/models/folder.go b/models/folder.go deleted file mode 100644 index 80f712c4..00000000 --- a/models/folder.go +++ /dev/null @@ -1,346 +0,0 @@ -package model - -import ( - "errors" - "path" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// Folder 目录 -type Folder struct { - // 表字段 - gorm.Model - Name string `gorm:"unique_index:idx_only_one_name"` - ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` - OwnerID uint `gorm:"index:owner_id"` - - // 数据库忽略字段 - Position string `gorm:"-"` - WebdavDstName string `gorm:"-"` -} - -// Create 创建目录 -func (folder *Folder) Create() (uint, error) { - if err := DB.FirstOrCreate(folder, *folder).Error; err != nil { - folder.Model = gorm.Model{} - err2 := DB.First(folder, *folder).Error - return folder.ID, err2 - } - - return folder.ID, nil -} - -// GetChild 返回folder下名为name的子目录,不存在则返回错误 -func (folder *Folder) GetChild(name string) (*Folder, error) { - var resFolder Folder - err := DB. - Where("parent_id = ? AND owner_id = ? AND name = ?", folder.ID, folder.OwnerID, name). - First(&resFolder).Error - - // 将子目录的路径传递下去 - if err == nil { - resFolder.Position = path.Join(folder.Position, folder.Name) - } - return &resFolder, err -} - -// TraceRoot 向上递归查找父目录 -func (folder *Folder) TraceRoot() error { - if folder.ParentID == nil { - return nil - } - - var parentFolder Folder - err := DB. - Where("id = ? AND owner_id = ?", folder.ParentID, folder.OwnerID). 
- First(&parentFolder).Error - - if err == nil { - err := parentFolder.TraceRoot() - folder.Position = path.Join(parentFolder.Position, parentFolder.Name) - return err - } - - return err -} - -// GetChildFolder 查找子目录 -func (folder *Folder) GetChildFolder() ([]Folder, error) { - var folders []Folder - result := DB.Where("parent_id = ?", folder.ID).Find(&folders) - - if result.Error == nil { - for i := 0; i < len(folders); i++ { - folders[i].Position = path.Join(folder.Position, folder.Name) - } - } - return folders, result.Error -} - -// GetRecursiveChildFolder 查找所有递归子目录,包括自身 -func GetRecursiveChildFolder(dirs []uint, uid uint, includeSelf bool) ([]Folder, error) { - folders := make([]Folder, 0, len(dirs)) - var err error - - var parFolders []Folder - result := DB.Where("owner_id = ? and id in (?)", uid, dirs).Find(&parFolders) - if result.Error != nil { - return folders, err - } - - // 整理父目录的ID - var parentIDs = make([]uint, 0, len(parFolders)) - for _, folder := range parFolders { - parentIDs = append(parentIDs, folder.ID) - } - - if includeSelf { - // 合并至最终结果 - folders = append(folders, parFolders...) - } - parFolders = []Folder{} - - // 递归查询子目录,最大递归65535次 - for i := 0; i < 65535; i++ { - - result = DB.Where("owner_id = ? and parent_id in (?)", uid, parentIDs).Find(&parFolders) - - // 查询结束条件 - if len(parFolders) == 0 { - break - } - - // 整理父目录的ID - parentIDs = make([]uint, 0, len(parFolders)) - for _, folder := range parFolders { - parentIDs = append(parentIDs, folder.ID) - } - - // 合并至最终结果 - folders = append(folders, parFolders...) - parFolders = []Folder{} - - } - - return folders, err -} - -// DeleteFolderByIDs 根据给定ID批量删除目录记录 -func DeleteFolderByIDs(ids []uint) error { - result := DB.Where("id in (?)", ids).Unscoped().Delete(&Folder{}) - return result.Error -} - -// GetFoldersByIDs 根据ID和用户查找所有目录 -func GetFoldersByIDs(ids []uint, uid uint) ([]Folder, error) { - var folders []Folder - result := DB.Where("id in (?) AND owner_id = ?", ids, uid).Find(&folders) - return folders, result.Error -} - -// MoveOrCopyFileTo 将此目录下的files移动或复制至dstFolder, -// 返回此操作新增的容量 -func (folder *Folder) MoveOrCopyFileTo(files []uint, dstFolder *Folder, isCopy bool) (uint64, error) { - // 已复制文件的总大小 - var copiedSize uint64 - - if isCopy { - // 检索出要复制的文件 - var originFiles = make([]File, 0, len(files)) - if err := DB.Where( - "id in (?) and user_id = ? and folder_id = ?", - files, - folder.OwnerID, - folder.ID, - ).Find(&originFiles).Error; err != nil { - return 0, err - } - - // 复制文件记录 - for _, oldFile := range originFiles { - if !oldFile.CanCopy() { - util.Log().Warning("Cannot copy file %q because it's being uploaded now, skipping...", oldFile.Name) - continue - } - - oldFile.Model = gorm.Model{} - oldFile.FolderID = dstFolder.ID - oldFile.UserID = dstFolder.OwnerID - - // webdav目标名重置 - if dstFolder.WebdavDstName != "" { - oldFile.Name = dstFolder.WebdavDstName - } - - if err := DB.Create(&oldFile).Error; err != nil { - return copiedSize, err - } - - copiedSize += oldFile.Size - } - - } else { - var updates = map[string]interface{}{ - "folder_id": dstFolder.ID, - } - // webdav目标名重置 - if dstFolder.WebdavDstName != "" { - updates["name"] = dstFolder.WebdavDstName - } - - // 更改顶级要移动文件的父目录指向 - err := DB.Model(File{}).Where( - "id in (?) and user_id = ? and folder_id = ?", - files, - folder.OwnerID, - folder.ID, - ). - Update(updates). 
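
Editor's aside (illustrative sketch, not part of the deleted files): GetRecursiveChildFolder above expands the folder tree level by level rather than recursing per row — each pass queries all children of the previous level's IDs and stops when a level comes back empty, with the 65535 iteration cap bounding the depth. The same traversal in plain Go over a hypothetical in-memory parent index instead of SQL:

package main

import "fmt"

// childrenOf maps a parent folder ID to its direct children (hypothetical data).
var childrenOf = map[uint][]uint{1: {2, 3}, 2: {4, 5}, 3: {6}}

// collectSubtree returns the IDs of a subtree, level by level, optionally
// including the roots themselves (mirroring the includeSelf flag above).
func collectSubtree(roots []uint, includeSelf bool) []uint {
	out := []uint{}
	if includeSelf {
		out = append(out, roots...)
	}
	frontier := roots
	for len(frontier) > 0 {
		next := []uint{}
		for _, id := range frontier {
			next = append(next, childrenOf[id]...)
		}
		out = append(out, next...)
		frontier = next
	}
	return out
}

func main() {
	fmt.Println(collectSubtree([]uint{1}, true)) // [1 2 3 4 5 6]
}
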
- Error - if err != nil { - return 0, err - } - - } - - return copiedSize, nil - -} - -// CopyFolderTo 将此目录及其子目录及文件递归复制至dstFolder -// 返回此操作新增的容量 -func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint64, err error) { - // 列出所有子目录 - subFolders, err := GetRecursiveChildFolder([]uint{folderID}, folder.OwnerID, true) - if err != nil { - return 0, err - } - - // 抽离所有子目录的ID - var subFolderIDs = make([]uint, len(subFolders)) - for key, value := range subFolders { - subFolderIDs[key] = value.ID - } - - // 复制子目录 - var newIDCache = make(map[uint]uint) - for _, folder := range subFolders { - // 新的父目录指向 - var newID uint - // 顶级目录直接指向新的目的目录 - if folder.ID == folderID { - newID = dstFolder.ID - // webdav目标名重置 - if dstFolder.WebdavDstName != "" { - folder.Name = dstFolder.WebdavDstName - } - } else if IDCache, ok := newIDCache[*folder.ParentID]; ok { - newID = IDCache - } else { - util.Log().Warning("Failed to get parent folder %q", *folder.ParentID) - return size, errors.New("Failed to get parent folder") - } - - // 插入新的目录记录 - oldID := folder.ID - folder.Model = gorm.Model{} - folder.ParentID = &newID - folder.OwnerID = dstFolder.OwnerID - if err = DB.Create(&folder).Error; err != nil { - return size, err - } - // 记录新的ID以便其子目录使用 - newIDCache[oldID] = folder.ID - - } - - // 复制文件 - var originFiles = make([]File, 0, len(subFolderIDs)) - if err := DB.Where( - "user_id = ? and folder_id in (?)", - folder.OwnerID, - subFolderIDs, - ).Find(&originFiles).Error; err != nil { - return 0, err - } - - // 复制文件记录 - for _, oldFile := range originFiles { - if !oldFile.CanCopy() { - util.Log().Warning("Cannot copy file %q because it's being uploaded now, skipping...", oldFile.Name) - continue - } - - oldFile.Model = gorm.Model{} - oldFile.FolderID = newIDCache[oldFile.FolderID] - oldFile.UserID = dstFolder.OwnerID - if err := DB.Create(&oldFile).Error; err != nil { - return size, err - } - - size += oldFile.Size - } - - return size, nil - -} - -// MoveFolderTo 将folder目录下的dirs子目录复制或移动到dstFolder, -// 返回此过程中增加的容量 -func (folder *Folder) MoveFolderTo(dirs []uint, dstFolder *Folder) error { - - // 如果目标位置为待移动的目录,会导致 parent 为自己 - // 造成死循环且无法被除搜索以外的组件展示 - if folder.OwnerID == dstFolder.OwnerID && util.ContainsUint(dirs, dstFolder.ID) { - return errors.New("cannot move a folder into itself") - } - - var updates = map[string]interface{}{ - "parent_id": dstFolder.ID, - } - // webdav目标名重置 - if dstFolder.WebdavDstName != "" { - updates["name"] = dstFolder.WebdavDstName - } - - // 更改顶级要移动目录的父目录指向 - err := DB.Model(Folder{}).Where( - "id in (?) and owner_id = ? 
and parent_id = ?", - dirs, - folder.OwnerID, - folder.ID, - ).Update(updates).Error - - return err - -} - -// Rename 重命名目录 -func (folder *Folder) Rename(new string) error { - return DB.Model(&folder).UpdateColumn("name", new).Error -} - -/* - 实现 FileInfo.FileInfo 接口 - TODO 测试 -*/ - -func (folder *Folder) GetName() string { - return folder.Name -} - -func (folder *Folder) GetSize() uint64 { - return 0 -} -func (folder *Folder) ModTime() time.Time { - return folder.UpdatedAt -} -func (folder *Folder) IsDir() bool { - return true -} -func (folder *Folder) GetPosition() string { - return folder.Position -} diff --git a/models/folder_test.go b/models/folder_test.go deleted file mode 100644 index 90220cac..00000000 --- a/models/folder_test.go +++ /dev/null @@ -1,622 +0,0 @@ -package model - -import ( - "errors" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFolder_Create(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Name: "new folder", - } - - // 不存在,插入成功 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - fid, err := folder.Create() - asserts.NoError(err) - asserts.Equal(uint(5), fid) - asserts.NoError(mock.ExpectationsWereMet()) - - // 插入失败 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - fid, err = folder.Create() - asserts.NoError(err) - asserts.Equal(uint(1), fid) - asserts.NoError(mock.ExpectationsWereMet()) - - // 存在,直接返回 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5)) - fid, err = folder.Create() - asserts.NoError(err) - asserts.Equal(uint(5), fid) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFolder_GetChild(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ID: 5}, - OwnerID: 1, - Name: "/", - } - - // 目录存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(5, 1, "sub"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "sub")) - sub, err := folder.GetChild("sub") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(sub.Name, "sub") - asserts.Equal("/", sub.Position) - } - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(5, 1, "sub"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - sub, err := folder.GetChild("sub") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint(0), sub.ID) - - } -} - -func TestFolder_GetChildFolder(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Model: gorm.Model{ - ID: 1, - }, - Position: "/123", - Name: "456", - } - - // 找不到 - mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) - files, err := folder.GetChildFolder() - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找到了 - mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) - files, err = folder.GetChildFolder() - asserts.NoError(err) - asserts.Len(files, 2) - asserts.Equal("/123/456", files[0].Position) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetRecursiveChildFolderSQLite(t *testing.T) { - conf.DatabaseConfig.Type = "sqlite" - asserts := assert.New(t) - - // 测试目录结构 - // 1 - // 2 3 - // 4 5 6 - - // 查询第一层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "folder1"), - ) - // 查询第二层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(2, "folder2"). - AddRow(3, "folder3"), - ) - // 查询第三层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "folder4"). - AddRow(5, "folder5"). - AddRow(6, "folder6"), - ) - // 查询第四层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 4, 5, 6). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}), - ) - - folders, err := GetRecursiveChildFolder([]uint{1}, 1, true) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(folders, 6) -} - -func TestDeleteFolderByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFolderByIDs([]uint{1, 2, 3}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - err := DeleteFolderByIDs([]uint{1, 2, 3}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } -} - -func TestGetFoldersByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnError(errors.New("error")) - folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Len(folders, 0) - } - - // 部分找到 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } -} - -func TestFolder_MoveOrCopyFileTo(t *testing.T) { - asserts := assert.New(t) - // 当前目录 - folder := Folder{ - Model: gorm.Model{ID: 1}, - OwnerID: 1, - Name: "test", - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - Name: "dst", - } - - // 复制文件 - { - mock.ExpectQuery("SELECT(.+)"). 
- WithArgs( - 1, - 2, - 3, - 1, - 1, - ).WillReturnRows( - sqlmock.NewRows([]string{"id", "size", "upload_session_id"}). - AddRow(1, 10, nil). - AddRow(2, 20, nil). - AddRow(2, 20, &folder.Name), - ) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2, 3}, - &dstFolder, - true, - ) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(30), storage) - } - - // 复制文件, 检索文件出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs( - 1, - 2, - 1, - 1, - ).WillReturnError(errors.New("error")) - - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - true, - ) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(0), storage) - } - - // 复制文件,第二个文件插入出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs( - 1, - 2, - 1, - 1, - ).WillReturnRows( - sqlmock.NewRows([]string{"id", "size"}). - AddRow(1, 10). - AddRow(2, 20), - ) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - true, - ) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(10), storage) - } - - // 移动文件 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1). - WillReturnResult(sqlmock.NewResult(1, 2)) - mock.ExpectCommit() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - false, - ) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(uint64(0), storage) - } - - // 移动文件 出错 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1). 
- WillReturnError(errors.New("error")) - mock.ExpectRollback() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - false, - ) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), storage) - } -} - -func TestFolder_CopyFolderTo(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - // 父目录 - parFolder := Folder{ - Model: gorm.Model{ID: 9}, - OwnerID: 1, - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - } - - // 测试复制目录结构 - // test(2)(5) - // 1(3)(6) 2.txt - // 3(4)(7) 4.txt 5.txt(上传中) - - // 正常情况 成功 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "folder_id", "size", "upload_session_id"}). - AddRow(1, "2.txt", 2, 10, nil). - AddRow(2, "3.txt", 3, 20, nil). - AddRow(3, "5.txt", 3, 20, &dstFolder.Name), - ) - - // 复制子文件 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(uint64(30), size) - } - - // 递归查询失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnError(errors.New("error")) - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 父目录ID不存在 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 99)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 查询子文件失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", 
"parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnError(errors.New("error")) - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 复制文件 一个失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "folder_id", "size"}). - AddRow(1, "2.txt", 2, 10). - AddRow(2, "3.txt", 3, 20), - ) - - // 复制子文件 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(10), size) - } - -} - -func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - // 父目录 - parFolder := Folder{ - Model: gorm.Model{ID: 9}, - OwnerID: 1, - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - OwnerID: 1, - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 9). 
- WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := parFolder.MoveFolderTo([]uint{1, 2}, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 移动自己到自己内部,失败 - { - err := parFolder.MoveFolderTo([]uint{10, 2}, &dstFolder) - asserts.Error(err) - } -} - -func TestFolder_FileInfoInterface(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ - UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), - }, - Name: "test_name", - OwnerID: 0, - Position: "/test", - } - - name := folder.GetName() - asserts.Equal("test_name", name) - - size := folder.GetSize() - asserts.Equal(uint64(0), size) - - asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), folder.ModTime()) - asserts.True(folder.IsDir()) - asserts.Equal("/test", folder.GetPosition()) -} - -func TestTraceRoot(t *testing.T) { - asserts := assert.New(t) - var parentId uint - parentId = 5 - folder := Folder{ - ParentID: &parentId, - OwnerID: 1, - Name: "test_name", - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/")) - asserts.NoError(folder.TraceRoot()) - asserts.Equal("/parent", folder.Position) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 出现错误 - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0). - WillReturnError(errors.New("error")) - asserts.Error(folder.TraceRoot()) - asserts.Equal("parent", folder.Position) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFolder_Rename(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test_name", - OwnerID: 1, - Position: "/test", - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("test_name_new", 1). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := folder.Rename("test_name_new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 出现错误 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("test_name_new", 1). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := folder.Rename("test_name_new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} diff --git a/models/group_test.go b/models/group_test.go deleted file mode 100644 index 2f487ce2..00000000 --- a/models/group_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestGetGroupByID(t *testing.T) { - asserts := assert.New(t) - - //找到用户组时 - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). 
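
Editor's aside (illustrative sketch, not part of the deleted files): the Group tests that follow (TestGroup_AfterFind and TestGroup_BeforeSave) exercise gorm hooks that keep a JSON-encoded string column such as "[1,2,3]" in sync with an in-memory []uint field. A minimal stand-alone version of that round trip, assuming the model resembles Cloudreve's Group; this is not the actual model code:

package main

import (
	"encoding/json"
	"fmt"
)

type group struct {
	Policies   string // persisted column, JSON-encoded
	PolicyList []uint // in-memory field, ignored by the ORM
}

// afterFind mirrors an AfterFind hook: decode the column into the slice.
func (g *group) afterFind() error {
	return json.Unmarshal([]byte(g.Policies), &g.PolicyList)
}

// beforeSave mirrors a BeforeSave hook: encode the slice back into the column.
func (g *group) beforeSave() error {
	b, err := json.Marshal(g.PolicyList)
	g.Policies = string(b)
	return err
}

func main() {
	g := group{Policies: "[1,2,3]"}
	fmt.Println(g.afterFind(), g.PolicyList) // <nil> [1 2 3]

	g.PolicyList = append(g.PolicyList, 4)
	fmt.Println(g.beforeSave(), g.Policies) // <nil> [1,2,3,4]
}
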
- AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - group, err := GetGroupByID(1) - asserts.NoError(err) - asserts.Equal(Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, group) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - group, err = GetGroupByID(1) - asserts.Error(err) - asserts.Equal(Group{}, group) -} - -func TestGroup_AfterFind(t *testing.T) { - asserts := assert.New(t) - - testCase := Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - } - err := testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{1}) - - testCase.Policies = "[1,2,3,4,5]" - err = testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{1, 2, 3, 4, 5}) - - testCase.Policies = "[1,2,3,4,5" - err = testCase.AfterFind() - asserts.Error(err) - - testCase.Policies = "[]" - err = testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{}) -} - -func TestGroup_BeforeSave(t *testing.T) { - asserts := assert.New(t) - group := Group{ - PolicyList: []uint{1, 2, 3}, - } - { - err := group.BeforeSave() - asserts.NoError(err) - asserts.Equal("[1,2,3]", group.Policies) - } - -} diff --git a/models/init.go b/models/init.go deleted file mode 100644 index a0920a92..00000000 --- a/models/init.go +++ /dev/null @@ -1,106 +0,0 @@ -package model - -import ( - "fmt" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - - _ "github.com/cloudreve/Cloudreve/v3/models/dialects" - _ "github.com/glebarez/go-sqlite" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" -) - -// DB 数据库链接单例 -var DB *gorm.DB - -// Init 初始化 MySQL 链接 -func Init() { - util.Log().Info("Initializing database connection...") - - var ( - db *gorm.DB - err error - confDBType string = conf.DatabaseConfig.Type - ) - - // 兼容已有配置中的 "sqlite3" 配置项 - if confDBType == "sqlite3" { - confDBType = "sqlite" - } - - if gin.Mode() == gin.TestMode { - // 测试模式下,使用内存数据库 - db, err = gorm.Open("sqlite", ":memory:") - } else { - switch confDBType { - case "UNSET", "sqlite": - // 未指定数据库或者明确指定为 sqlite 时,使用 SQLite 数据库 - db, err = gorm.Open("sqlite", util.RelativePath(conf.DatabaseConfig.DBFile)) - case "postgres": - db, err = gorm.Open(confDBType, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable", - conf.DatabaseConfig.Host, - conf.DatabaseConfig.User, - conf.DatabaseConfig.Password, - conf.DatabaseConfig.Name, - conf.DatabaseConfig.Port)) - case "mysql", "mssql": - var host string - if conf.DatabaseConfig.UnixSocket { - host = fmt.Sprintf("unix(%s)", - conf.DatabaseConfig.Host) - } else { - host = fmt.Sprintf("(%s:%d)", - conf.DatabaseConfig.Host, - conf.DatabaseConfig.Port) - } - - db, err = gorm.Open(confDBType, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local", - conf.DatabaseConfig.User, - conf.DatabaseConfig.Password, - host, - conf.DatabaseConfig.Name, - conf.DatabaseConfig.Charset)) - default: - util.Log().Panic("Unsupported database type %q.", confDBType) - } - } - - //db.SetLogger(util.Log()) - if err != nil { - util.Log().Panic("Failed to connect to database: %s", err) - } - - // 处理表前缀 - gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string { - return 
conf.DatabaseConfig.TablePrefix + defaultTableName - } - - // Debug模式下,输出所有 SQL 日志 - if conf.SystemConfig.Debug { - db.LogMode(true) - } else { - db.LogMode(false) - } - - //设置连接池 - db.DB().SetMaxIdleConns(50) - if confDBType == "sqlite" || confDBType == "UNSET" { - db.DB().SetMaxOpenConns(1) - } else { - db.DB().SetMaxOpenConns(100) - } - - //超时 - db.DB().SetConnMaxLifetime(time.Second * 30) - - DB = db - - //执行迁移 - migration() -} diff --git a/models/migration.go b/models/migration.go deleted file mode 100644 index fad6a766..00000000 --- a/models/migration.go +++ /dev/null @@ -1,218 +0,0 @@ -package model - -import ( - "context" - "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/fatih/color" - "github.com/hashicorp/go-version" - "github.com/jinzhu/gorm" - "sort" - "strings" -) - -// 是否需要迁移 -func needMigration() bool { - var setting Setting - return DB.Where("name = ?", "db_version_"+conf.RequiredDBVersion).First(&setting).Error != nil -} - -// 执行数据迁移 -func migration() { - // 确认是否需要执行迁移 - if !needMigration() { - util.Log().Info("Database version fulfilled, skip schema migration.") - return - - } - - util.Log().Info("Start initializing database schema...") - - // 清除所有缓存 - if instance, ok := cache.Store.(*cache.RedisStore); ok { - instance.DeleteAll() - } - - // 自动迁移模式 - if conf.DatabaseConfig.Type == "mysql" { - DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") - } - - DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{}, - &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{}) - - // 创建初始存储策略 - addDefaultPolicy() - - // 创建初始用户组 - addDefaultGroups() - - // 创建初始管理员账户 - addDefaultUser() - - // 创建初始节点 - addDefaultNode() - - // 向设置数据表添加初始设置 - addDefaultSettings() - - // 执行数据库升级脚本 - execUpgradeScripts() - - util.Log().Info("Finish initializing database schema.") - -} - -func addDefaultPolicy() { - _, err := GetPolicyByID(uint(1)) - // 未找到初始存储策略时,则创建 - if gorm.IsRecordNotFoundError(err) { - defaultPolicy := Policy{ - Name: "Default storage policy", - Type: "local", - MaxSize: 0, - AutoRename: true, - DirNameRule: "uploads/{uid}/{path}", - FileNameRule: "{uid}_{randomkey8}_{originname}", - IsOriginLinkEnable: false, - OptionsSerialized: PolicyOption{ - ChunkSize: 25 << 20, // 25MB - }, - } - if err := DB.Create(&defaultPolicy).Error; err != nil { - util.Log().Panic("Failed to create default storage policy: %s", err) - } - } -} - -func addDefaultSettings() { - for _, value := range defaultSettings { - DB.Where(Setting{Name: value.Name}).Create(&value) - } -} - -func addDefaultGroups() { - _, err := GetGroupByID(1) - // 未找到初始管理组时,则创建 - if gorm.IsRecordNotFoundError(err) { - defaultAdminGroup := Group{ - Name: "Admin", - PolicyList: []uint{1}, - MaxStorage: 1 * 1024 * 1024 * 1024, - ShareEnabled: true, - WebDAVEnabled: true, - OptionsSerialized: GroupOption{ - ArchiveDownload: true, - ArchiveTask: true, - ShareDownload: true, - Aria2: true, - SourceBatchSize: 1000, - Aria2BatchSize: 50, - RedirectedSource: true, - AdvanceDelete: true, - }, - } - if err := DB.Create(&defaultAdminGroup).Error; err != nil { - util.Log().Panic("Failed to create admin user group: %s", err) - } - } - - err = nil - _, err = GetGroupByID(2) - // 未找到初始注册会员时,则创建 - if gorm.IsRecordNotFoundError(err) { - defaultAdminGroup := Group{ - Name: "User", - PolicyList: []uint{1}, - MaxStorage: 1 * 1024 * 1024 * 1024, - 
ShareEnabled: true, - WebDAVEnabled: true, - OptionsSerialized: GroupOption{ - ShareDownload: true, - SourceBatchSize: 10, - Aria2BatchSize: 1, - RedirectedSource: true, - }, - } - if err := DB.Create(&defaultAdminGroup).Error; err != nil { - util.Log().Panic("Failed to create initial user group: %s", err) - } - } - - err = nil - _, err = GetGroupByID(3) - // 未找到初始游客用户组时,则创建 - if gorm.IsRecordNotFoundError(err) { - defaultAdminGroup := Group{ - Name: "Anonymous", - PolicyList: []uint{}, - Policies: "[]", - OptionsSerialized: GroupOption{ - ShareDownload: true, - }, - } - if err := DB.Create(&defaultAdminGroup).Error; err != nil { - util.Log().Panic("Failed to create anonymous user group: %s", err) - } - } -} - -func addDefaultUser() { - _, err := GetUserByID(1) - password := util.RandStringRunes(8) - - // 未找到初始用户时,则创建 - if gorm.IsRecordNotFoundError(err) { - defaultUser := NewUser() - defaultUser.Email = "admin@cloudreve.org" - defaultUser.Nick = "admin" - defaultUser.Status = Active - defaultUser.GroupID = 1 - err := defaultUser.SetPassword(password) - if err != nil { - util.Log().Panic("Failed to create password: %s", err) - } - if err := DB.Create(&defaultUser).Error; err != nil { - util.Log().Panic("Failed to create initial root user: %s", err) - } - - c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold) - util.Log().Info("Admin user name: " + c.Sprint("admin@cloudreve.org")) - util.Log().Info("Admin password: " + c.Sprint(password)) - } -} - -func addDefaultNode() { - _, err := GetNodeByID(1) - - if gorm.IsRecordNotFoundError(err) { - defaultAdminGroup := Node{ - Name: "Master (Local machine)", - Status: NodeActive, - Type: MasterNodeType, - Aria2OptionsSerialized: Aria2Option{ - Interval: 10, - Timeout: 10, - }, - } - if err := DB.Create(&defaultAdminGroup).Error; err != nil { - util.Log().Panic("Failed to create initial node: %s", err) - } - } -} - -func execUpgradeScripts() { - s := invoker.ListPrefix("UpgradeTo") - versions := make([]*version.Version, len(s)) - for i, raw := range s { - v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo")) - versions[i] = v - } - sort.Sort(version.Collection(versions)) - - for i := 0; i < len(versions); i++ { - invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background()) - } -} diff --git a/models/migration_test.go b/models/migration_test.go deleted file mode 100644 index 7c9d673a..00000000 --- a/models/migration_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package model - -import ( - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestMigration(t *testing.T) { - asserts := assert.New(t) - conf.DatabaseConfig.Type = "sqlite" - DB, _ = gorm.Open("sqlite", ":memory:") - - asserts.NotPanics(func() { - migration() - }) - conf.DatabaseConfig.Type = "mysql" - DB = mockDB -} diff --git a/models/node_test.go b/models/node_test.go deleted file mode 100644 index de1757f6..00000000 --- a/models/node_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestGetNodeByID(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetNodeByID(1) - a.NoError(err) - a.EqualValues(1, res.ID) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestGetNodesByStatus(t *testing.T) { - a := assert.New(t) - 
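
Editor's aside (illustrative sketch, not part of the deleted files): execUpgradeScripts above strips the "UpgradeTo" prefix from registered script names, sorts the remainders as semantic versions, and runs them oldest-first; plain string sorting would misorder 3.10.0 before 3.5.3. The ordering with hashicorp/go-version, using a made-up version list:

package main

import (
	"fmt"
	"sort"
	"strings"

	"github.com/hashicorp/go-version"
)

func main() {
	names := []string{"UpgradeTo3.5.3", "UpgradeTo3.4.0", "UpgradeTo3.10.0"}

	versions := make([]*version.Version, 0, len(names))
	for _, raw := range names {
		v, err := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
		if err != nil {
			continue // skip names that are not valid versions
		}
		versions = append(versions, v)
	}

	// version.Collection sorts semantically, so 3.10.0 lands after 3.5.3.
	sort.Sort(version.Collection(versions))

	for _, v := range versions {
		fmt.Println("run UpgradeTo" + v.String())
	}
}
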
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive)) - res, err := GetNodesByStatus(NodeActive) - a.NoError(err) - a.Len(res, 1) - a.EqualValues(NodeActive, res[0].Status) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestNode_AfterFind(t *testing.T) { - a := assert.New(t) - node := &Node{} - - // No aria2 options - { - a.NoError(node.AfterFind()) - } - - // with aria2 options - { - node.Aria2Options = `{"timeout":1}` - a.NoError(node.AfterFind()) - a.Equal(1, node.Aria2OptionsSerialized.Timeout) - } -} - -func TestNode_BeforeSave(t *testing.T) { - a := assert.New(t) - node := &Node{} - - node.Aria2OptionsSerialized.Timeout = 1 - a.NoError(node.BeforeSave()) - a.Contains(node.Aria2Options, "1") -} - -func TestNode_SetStatus(t *testing.T) { - a := assert.New(t) - node := &Node{} - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(node.SetStatus(NodeActive)) - a.Equal(NodeActive, node.Status) - a.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/policy.go b/models/policy.go deleted file mode 100644 index 11d8e4b6..00000000 --- a/models/policy.go +++ /dev/null @@ -1,243 +0,0 @@ -package model - -import ( - "encoding/gob" - "encoding/json" - "github.com/gofrs/uuid" - "github.com/samber/lo" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// Policy 存储策略 -type Policy struct { - // 表字段 - gorm.Model - Name string - Type string - Server string - BucketName string - IsPrivate bool - BaseURL string - AccessKey string `gorm:"type:text"` - SecretKey string `gorm:"type:text"` - MaxSize uint64 - AutoRename bool - DirNameRule string - FileNameRule string - IsOriginLinkEnable bool - Options string `gorm:"type:text"` - - // 数据库忽略字段 - OptionsSerialized PolicyOption `gorm:"-"` - MasterID string `gorm:"-"` -} - -// PolicyOption 非公有的存储策略属性 -type PolicyOption struct { - // Upyun访问Token - Token string `json:"token"` - // 允许的文件扩展名 - FileType []string `json:"file_type"` - // MimeType - MimeType string `json:"mimetype"` - // OauthRedirect Oauth 重定向地址 - OauthRedirect string `json:"od_redirect,omitempty"` - // OdProxy Onedrive 反代地址 - OdProxy string `json:"od_proxy,omitempty"` - // OdDriver OneDrive 驱动器定位符 - OdDriver string `json:"od_driver,omitempty"` - // Region 区域代码 - Region string `json:"region,omitempty"` - // ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段 - ServerSideEndpoint string `json:"server_side_endpoint,omitempty"` - // 分片上传的分片大小 - ChunkSize uint64 `json:"chunk_size,omitempty"` - // 分片上传时是否需要预留空间 - PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"` - // 每秒对存储端的 API 请求上限 - TPSLimit float64 `json:"tps_limit,omitempty"` - // 每秒 API 请求爆发上限 - TPSLimitBurst int `json:"tps_limit_burst,omitempty"` - // Set this to `true` to force the request to use path-style addressing, - // i.e., `http://s3.amazonaws.com/BUCKET/KEY ` - S3ForcePathStyle bool `json:"s3_path_style"` - // File extensions that support thumbnail generation using native policy API. 
- ThumbExts []string `json:"thumb_exts,omitempty"` -} - -func init() { - // 注册缓存用到的复杂结构 - gob.Register(Policy{}) -} - -// GetPolicyByID 用ID获取存储策略 -func GetPolicyByID(ID interface{}) (Policy, error) { - // 尝试读取缓存 - cacheKey := "policy_" + strconv.Itoa(int(ID.(uint))) - if policy, ok := cache.Get(cacheKey); ok { - return policy.(Policy), nil - } - - var policy Policy - result := DB.First(&policy, ID) - - // 写入缓存 - if result.Error == nil { - _ = cache.Set(cacheKey, policy, -1) - } - - return policy, result.Error -} - -// AfterFind 找到存储策略后的钩子 -func (policy *Policy) AfterFind() (err error) { - // 解析存储策略设置到OptionsSerialized - if policy.Options != "" { - err = json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized) - } - if policy.OptionsSerialized.FileType == nil { - policy.OptionsSerialized.FileType = []string{} - } - - return err -} - -// BeforeSave Save策略前的钩子 -func (policy *Policy) BeforeSave() (err error) { - err = policy.SerializeOptions() - return err -} - -// SerializeOptions 将序列后的Option写入到数据库字段 -func (policy *Policy) SerializeOptions() (err error) { - optionsValue, err := json.Marshal(&policy.OptionsSerialized) - policy.Options = string(optionsValue) - return err -} - -// GeneratePath 生成存储文件的路径 -func (policy *Policy) GeneratePath(uid uint, origin string) string { - dirRule := policy.DirNameRule - replaceTable := map[string]string{ - "{randomkey16}": util.RandStringRunes(16), - "{randomkey8}": util.RandStringRunes(8), - "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), - "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), - "{uid}": strconv.Itoa(int(uid)), - "{datetime}": time.Now().Format("20060102150405"), - "{date}": time.Now().Format("20060102"), - "{year}": time.Now().Format("2006"), - "{month}": time.Now().Format("01"), - "{day}": time.Now().Format("02"), - "{hour}": time.Now().Format("15"), - "{minute}": time.Now().Format("04"), - "{second}": time.Now().Format("05"), - "{path}": origin + "/", - } - dirRule = util.Replace(replaceTable, dirRule) - return path.Clean(dirRule) -} - -// GenerateFileName 生成存储文件名 -func (policy *Policy) GenerateFileName(uid uint, origin string) string { - // 未开启自动重命名时,直接返回原始文件名 - if !policy.AutoRename { - return origin - } - - fileRule := policy.FileNameRule - - replaceTable := map[string]string{ - "{randomkey16}": util.RandStringRunes(16), - "{randomkey8}": util.RandStringRunes(8), - "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), - "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), - "{uid}": strconv.Itoa(int(uid)), - "{datetime}": time.Now().Format("20060102150405"), - "{date}": time.Now().Format("20060102"), - "{year}": time.Now().Format("2006"), - "{month}": time.Now().Format("01"), - "{day}": time.Now().Format("02"), - "{hour}": time.Now().Format("15"), - "{minute}": time.Now().Format("04"), - "{second}": time.Now().Format("05"), - "{originname}": origin, - "{ext}": filepath.Ext(origin), - "{originname_without_ext}": strings.TrimSuffix(origin, filepath.Ext(origin)), - "{uuid}": uuid.Must(uuid.NewV4()).String(), - } - - fileRule = util.Replace(replaceTable, fileRule) - return fileRule -} - -// IsDirectlyPreview 返回此策略下文件是否可以直接预览(不需要重定向) -func (policy *Policy) IsDirectlyPreview() bool { - return policy.Type == "local" -} - -// IsTransitUpload 返回此策略上传给定size文件时是否需要服务端中转 -func (policy *Policy) IsTransitUpload(size uint64) bool { - return policy.Type == "local" -} - -// IsThumbGenerateNeeded 返回此策略是否需要在上传后生成缩略图 -func (policy *Policy) IsThumbGenerateNeeded() bool { - return policy.Type == "local" -} - -// 
IsUploadPlaceholderWithSize 返回此策略创建上传会话时是否需要预留空间 -func (policy *Policy) IsUploadPlaceholderWithSize() bool { - if policy.Type == "remote" { - return true - } - - if util.ContainsString([]string{"onedrive", "oss", "qiniu", "cos", "s3"}, policy.Type) { - return policy.OptionsSerialized.PlaceholderWithSize - } - - return false -} - -// CanStructureBeListed 返回存储策略是否能被前台列物理目录 -func (policy *Policy) CanStructureBeListed() bool { - return policy.Type != "local" && policy.Type != "remote" -} - -// SaveAndClearCache 更新并清理缓存 -func (policy *Policy) SaveAndClearCache() error { - err := DB.Save(policy).Error - policy.ClearCache() - return err -} - -// SaveAndClearCache 更新并清理缓存 -func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error { - err := DB.Model(policy).UpdateColumn("access_key", s).Error - policy.ClearCache() - return err -} - -// ClearCache 清空policy缓存 -func (policy *Policy) ClearCache() { - cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_") -} - -// CouldProxyThumb return if proxy thumbs is allowed for this policy. -func (policy *Policy) CouldProxyThumb() bool { - if policy.Type == "local" || !IsTrueVal(GetSettingByName("thumb_proxy_enabled")) { - return false - } - - allowed := make([]uint, 0) - _ = json.Unmarshal([]byte(GetSettingByName("thumb_proxy_policy")), &allowed) - return lo.Contains[uint](allowed, policy.ID) -} diff --git a/models/policy_test.go b/models/policy_test.go deleted file mode 100644 index f7d4e747..00000000 --- a/models/policy_test.go +++ /dev/null @@ -1,269 +0,0 @@ -package model - -import ( - "encoding/json" - "strconv" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestGetPolicyByID(t *testing.T) { - asserts := assert.New(t) - - cache.Deletes([]string{"22", "23"}, "policy_") - // 缓存未命中 - { - rows := sqlmock.NewRows([]string{"name", "type", "options"}). 
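
Editor's aside (illustrative sketch, not part of the deleted files): GetPolicyByID above is a cache-aside lookup — consult the cache, fall back to the database on a miss, then write the row back so the next call (as in the TestGetPolicyByID cases nearby) never reaches the database. A minimal version of that flow using a plain map in place of Cloudreve's pkg/cache store and the policies table:

package main

import "fmt"

type policy struct {
	ID   uint
	Name string
}

var (
	cacheStore = map[string]policy{}                            // stand-in for the cache layer
	dbRows     = map[uint]policy{22: {ID: 22, Name: "default"}} // stand-in for the policies table
)

func getPolicyByID(id uint) (policy, bool) {
	key := fmt.Sprintf("policy_%d", id)
	if p, ok := cacheStore[key]; ok {
		return p, true // cache hit, no database round trip
	}
	p, ok := dbRows[id] // stands in for "SELECT * FROM policies WHERE id = ?"
	if ok {
		cacheStore[key] = p // write back so the next lookup is served from cache
	}
	return p, ok
}

func main() {
	first, _ := getPolicyByID(22)  // miss: read from the table, fill the cache
	second, _ := getPolicyByID(22) // hit: served from the cache
	fmt.Println(first.Name, second.Name)
}
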
- AddRow("默认存储策略", "local", "{\"od_redirect\":\"123\"}") - mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) - policy, err := GetPolicyByID(uint(22)) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) - - rows = sqlmock.NewRows([]string{"name", "type", "options"}) - mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) - policy, err = GetPolicyByID(uint(23)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 命中 - { - policy, err := GetPolicyByID(uint(22)) - asserts.NoError(err) - asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) - - } - -} - -func TestPolicy_BeforeSave(t *testing.T) { - asserts := assert.New(t) - - testPolicy := Policy{ - OptionsSerialized: PolicyOption{ - OauthRedirect: "123", - }, - } - expected, _ := json.Marshal(testPolicy.OptionsSerialized) - err := testPolicy.BeforeSave() - asserts.NoError(err) - asserts.Equal(string(expected), testPolicy.Options) - -} - -func TestPolicy_GeneratePath(t *testing.T) { - asserts := assert.New(t) - testPolicy := Policy{} - - testPolicy.DirNameRule = "{randomkey16}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 16) - - testPolicy.DirNameRule = "{randomkey8}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 8) - - testPolicy.DirNameRule = "{timestamp}" - asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.FormatInt(time.Now().Unix(), 10)) - - testPolicy.DirNameRule = "{uid}" - asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.Itoa(int(1))) - - testPolicy.DirNameRule = "{datetime}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 14) - - testPolicy.DirNameRule = "{date}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 8) - - testPolicy.DirNameRule = "123{date}ss{datetime}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 27) - - testPolicy.DirNameRule = "/1/{path}/456" - asserts.Condition(func() (success bool) { - res := testPolicy.GeneratePath(1, "/23") - return res == "/1/23/456" || res == "\\1\\23\\456" - }) - -} - -func TestPolicy_GenerateFileName(t *testing.T) { - asserts := assert.New(t) - // 重命名关闭 - { - testPolicy := Policy{ - AutoRename: false, - } - testPolicy.FileNameRule = "{randomkey16}" - asserts.Equal("123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "oss" - asserts.Equal("origin", testPolicy.GenerateFileName(1, "origin")) - } - - // 重命名开启 - { - testPolicy := Policy{ - AutoRename: true, - } - - testPolicy.FileNameRule = "{randomkey16}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16) - - testPolicy.FileNameRule = "{randomkey8}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) - - testPolicy.FileNameRule = "{timestamp}" - asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.FormatInt(time.Now().Unix(), 10)) - - testPolicy.FileNameRule = "{uid}" - asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.Itoa(int(1))) - - testPolicy.FileNameRule = "{datetime}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 14) - - testPolicy.FileNameRule = "{date}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) - - testPolicy.FileNameRule = "123{date}ss{datetime}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 27) - - testPolicy.FileNameRule = "{originname_without_ext}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 3) - - testPolicy.FileNameRule = "{originname_without_ext}_{randomkey8}{ext}" - asserts.Len(testPolicy.GenerateFileName(1, 
"123.txt"), 16) - - // 支持{originname}的策略 - testPolicy.Type = "local" - testPolicy.FileNameRule = "123{originname}" - asserts.Equal("123123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "qiniu" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "oss" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "upyun" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "qiniu" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "local" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123", testPolicy.GenerateFileName(1, "")) - - testPolicy.Type = "local" - testPolicy.FileNameRule = "{ext}123{uuid}" - asserts.Contains(testPolicy.GenerateFileName(1, "123.txt"), ".txt123") - } - -} - -func TestPolicy_IsDirectlyPreview(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Type: "local"} - asserts.True(policy.IsDirectlyPreview()) - policy.Type = "remote" - asserts.False(policy.IsDirectlyPreview()) -} - -func TestPolicy_ClearCache(t *testing.T) { - asserts := assert.New(t) - cache.Set("policy_202", 1, 0) - policy := Policy{Model: gorm.Model{ID: 202}} - policy.ClearCache() - _, ok := cache.Get("policy_202") - asserts.False(ok) -} - -func TestPolicy_UpdateAccessKey(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Model: gorm.Model{ID: 202}} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - policy.AccessKey = "123" - err := policy.SaveAndClearCache() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) -} - -func TestPolicy_Props(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Type: "onedrive"} - policy.OptionsSerialized.PlaceholderWithSize = true - asserts.False(policy.IsThumbGenerateNeeded()) - asserts.False(policy.IsTransitUpload(4)) - asserts.False(policy.IsTransitUpload(5 * 1024 * 1024)) - asserts.True(policy.CanStructureBeListed()) - asserts.True(policy.IsUploadPlaceholderWithSize()) - policy.Type = "local" - asserts.True(policy.IsThumbGenerateNeeded()) - asserts.False(policy.CanStructureBeListed()) - asserts.False(policy.IsUploadPlaceholderWithSize()) - policy.Type = "remote" - asserts.True(policy.IsUploadPlaceholderWithSize()) -} - -func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) { - a := assert.New(t) - cache.Set("policy_1331", Policy{}, 3600) - p := &Policy{} - p.ID = 1331 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(p.UpdateAccessKeyAndClearCache("ak")) - a.NoError(mock.ExpectationsWereMet()) - _, ok := cache.Get("policy_1331") - a.False(ok) -} - -func TestPolicy_CouldProxyThumb(t *testing.T) { - a := assert.New(t) - p := &Policy{Type: "local"} - - // local policy - { - a.False(p.CouldProxyThumb()) - } - - // feature not enabled - { - p.Type = "remote" - cache.Set("setting_thumb_proxy_enabled", "0", 0) - a.False(p.CouldProxyThumb()) - } - - // list not contain current policy - { - p.ID = 2 - cache.Set("setting_thumb_proxy_enabled", "1", 0) - cache.Set("setting_thumb_proxy_policy", "[1]", 0) - a.False(p.CouldProxyThumb()) - } - - // enabled 
- { - p.ID = 2 - cache.Set("setting_thumb_proxy_enabled", "1", 0) - cache.Set("setting_thumb_proxy_policy", "[2]", 0) - a.True(p.CouldProxyThumb()) - } - - cache.Deletes([]string{"thumb_proxy_enabled", "thumb_proxy_policy"}, "setting_") -} diff --git a/models/scripts/init.go b/models/scripts/init.go deleted file mode 100644 index 7c375bf4..00000000 --- a/models/scripts/init.go +++ /dev/null @@ -1,9 +0,0 @@ -package scripts - -import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" - -func Init() { - invoker.Register("ResetAdminPassword", ResetAdminPassword(0)) - invoker.Register("CalibrateUserStorage", UserStorageCalibration(0)) - invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0)) -} diff --git a/models/scripts/invoker/invoker.go b/models/scripts/invoker/invoker.go deleted file mode 100644 index b55b1e90..00000000 --- a/models/scripts/invoker/invoker.go +++ /dev/null @@ -1,38 +0,0 @@ -package invoker - -import ( - "context" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "strings" -) - -type DBScript interface { - Run(ctx context.Context) -} - -var availableScripts = make(map[string]DBScript) - -func RunDBScript(name string, ctx context.Context) error { - if script, ok := availableScripts[name]; ok { - util.Log().Info("Start executing database script %q.", name) - script.Run(ctx) - return nil - } - - return fmt.Errorf("Database script %q not exist.", name) -} - -func Register(name string, script DBScript) { - availableScripts[name] = script -} - -func ListPrefix(prefix string) []string { - var scripts []string - for name := range availableScripts { - if strings.HasPrefix(name, prefix) { - scripts = append(scripts, name) - } - } - return scripts -} diff --git a/models/scripts/invoker/invoker_test.go b/models/scripts/invoker/invoker_test.go deleted file mode 100644 index 36651eb4..00000000 --- a/models/scripts/invoker/invoker_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package invoker - -import ( - "context" - "github.com/stretchr/testify/assert" - "testing" -) - -type TestScript int - -func (script TestScript) Run(ctx context.Context) { - -} - -func TestRunDBScript(t *testing.T) { - asserts := assert.New(t) - Register("test", TestScript(0)) - - // 不存在 - { - asserts.Error(RunDBScript("else", context.Background())) - } - - // 存在 - { - asserts.NoError(RunDBScript("test", context.Background())) - } -} - -func TestListPrefix(t *testing.T) { - asserts := assert.New(t) - Register("U1", TestScript(0)) - Register("U2", TestScript(0)) - Register("U3", TestScript(0)) - Register("P1", TestScript(0)) - - res := ListPrefix("U") - asserts.Len(res, 3) -} diff --git a/models/scripts/reset.go b/models/scripts/reset.go deleted file mode 100644 index 1f6bf08d..00000000 --- a/models/scripts/reset.go +++ /dev/null @@ -1,31 +0,0 @@ -package scripts - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/fatih/color" -) - -type ResetAdminPassword int - -// Run 运行脚本从社区版升级至 Pro 版 -func (script ResetAdminPassword) Run(ctx context.Context) { - // 查找用户 - user, err := model.GetUserByID(1) - if err != nil { - util.Log().Panic("Initial admin user not exist: %s", err) - } - - // 生成密码 - password := util.RandStringRunes(8) - - // 更改为新密码 - user.SetPassword(password) - if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil { - util.Log().Panic("Failed to update password: %s", err) - } - - c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold) - util.Log().Info("Initial admin user password 
changed to:" + c.Sprint(password)) -} diff --git a/models/scripts/reset_test.go b/models/scripts/reset_test.go deleted file mode 100644 index ffacb28b..00000000 --- a/models/scripts/reset_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package scripts - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestResetAdminPassword_Run(t *testing.T) { - asserts := assert.New(t) - script := ResetAdminPassword(0) - - // 初始用户不存在 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"})) - asserts.Panics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 密码更新失败 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - asserts.Panics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NotPanics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/scripts/storage.go b/models/scripts/storage.go deleted file mode 100644 index 0d436b9f..00000000 --- a/models/scripts/storage.go +++ /dev/null @@ -1,33 +0,0 @@ -package scripts - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -type UserStorageCalibration int - -type storageResult struct { - Total uint64 -} - -// Run 运行脚本校准所有用户容量 -func (script UserStorageCalibration) Run(ctx context.Context) { - // 列出所有用户 - var res []model.User - model.DB.Model(&model.User{}).Find(&res) - - // 逐个检查容量 - for _, user := range res { - // 计算正确的容量 - var total storageResult - model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total) - // 更新用户的容量 - if user.Storage != total.Total { - util.Log().Info("Calibrate used storage for user %q, from %d to %d.", user.Email, - user.Storage, total.Total) - } - model.DB.Model(&user).Update("storage", total.Total) - } -} diff --git a/models/scripts/storage_test.go b/models/scripts/storage_test.go deleted file mode 100644 index 746f0c00..00000000 --- a/models/scripts/storage_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package scripts - -import ( - "context" - "database/sql" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -var mock sqlmock.Sqlmock -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - mockDB = model.DB - defer db.Close() - m.Run() -} - -func TestUserStorageCalibration_Run(t *testing.T) { - asserts := assert.New(t) - script := UserStorageCalibration(0) - - // 容量异常 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 容量正常 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/scripts/upgrade.go b/models/scripts/upgrade.go deleted file mode 100644 index 717a72eb..00000000 --- a/models/scripts/upgrade.go +++ /dev/null @@ -1,43 +0,0 @@ -package scripts - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "strconv" -) - -type UpgradeTo340 int - -// Run upgrade from older version to 3.4.0 -func (script UpgradeTo340) Run(ctx context.Context) { - // 取回老版本 aria2 设定 - old := model.GetSettingByType([]string{"aria2"}) - if len(old) == 0 { - return - } - - // 写入到新版本的节点设定 - n, err := model.GetNodeByID(1) - if err != nil { - util.Log().Error("找不到主机节点, %s", err) - } - - n.Aria2Enabled = old["aria2_rpcurl"] != "" - n.Aria2OptionsSerialized.Options = old["aria2_options"] - n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"] - - interval, err := strconv.Atoi(old["aria2_interval"]) - if err != nil { - interval = 10 - } - n.Aria2OptionsSerialized.Interval = interval - n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"] - n.Aria2OptionsSerialized.Token = old["aria2_token"] - if err := model.DB.Save(&n).Error; err != nil { - util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err) - } else { - model.DB.Where("type = ?", "aria2").Delete(model.Setting{}) - util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式") - } -} diff --git a/models/scripts/upgrade_test.go b/models/scripts/upgrade_test.go deleted file mode 100644 index 8f7adbaf..00000000 --- a/models/scripts/upgrade_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package scripts - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestUpgradeTo340_Run(t *testing.T) { - a := assert.New(t) - script := UpgradeTo340(0) - - // skip - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"})) - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // node not found - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1")) - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"})) - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // success - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). - AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). - AddRow("aria2_interval", "expected_aria2_interval"). - AddRow("aria2_temp_path", "expected_aria2_temp_path"). - AddRow("aria2_token", "expected_aria2_token"). 
- AddRow("aria2_options", "{}")) - - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // failed - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). - AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). - AddRow("aria2_interval", "expected_aria2_interval"). - AddRow("aria2_temp_path", "expected_aria2_temp_path"). - AddRow("aria2_token", "expected_aria2_token"). - AddRow("aria2_options", "{}")) - - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/setting.go b/models/setting.go deleted file mode 100644 index 0bbcf68d..00000000 --- a/models/setting.go +++ /dev/null @@ -1,110 +0,0 @@ -package model - -import ( - "net/url" - "strconv" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" -) - -// Setting 系统设置模型 -type Setting struct { - gorm.Model - Type string `gorm:"not null"` - Name string `gorm:"unique;not null;index:setting_key"` - Value string `gorm:"size:‎65535"` -} - -// IsTrueVal 返回设置的值是否为真 -func IsTrueVal(val string) bool { - return val == "1" || val == "true" -} - -// GetSettingByName 用 Name 获取设置值 -func GetSettingByName(name string) string { - return GetSettingByNameFromTx(DB, name) -} - -// GetSettingByNameFromTx 用 Name 获取设置值,使用事务 -func GetSettingByNameFromTx(tx *gorm.DB, name string) string { - var setting Setting - - // 优先从缓存中查找 - cacheKey := "setting_" + name - if optionValue, ok := cache.Get(cacheKey); ok { - return optionValue.(string) - } - - // 尝试数据库中查找 - if tx == nil { - tx = DB - if tx == nil { - return "" - } - } - - result := tx.Where("name = ?", name).First(&setting) - if result.Error == nil { - _ = cache.Set(cacheKey, setting.Value, -1) - return setting.Value - } - - return "" -} - -// GetSettingByNameWithDefault 用 Name 获取设置值, 取不到时使用缺省值 -func GetSettingByNameWithDefault(name, fallback string) string { - res := GetSettingByName(name) - if res == "" { - return fallback - } - return res -} - -// GetSettingByNames 用多个 Name 获取设置值 -func GetSettingByNames(names ...string) map[string]string { - var queryRes []Setting - res, miss := cache.GetSettings(names, "setting_") - - if len(miss) > 0 { - DB.Where("name IN (?)", miss).Find(&queryRes) - for _, setting := range queryRes { - res[setting.Name] = setting.Value - } - } - - _ = cache.SetSettings(res, "setting_") - return res -} - -// GetSettingByType 获取一个或多个分组的所有设置值 -func GetSettingByType(types []string) map[string]string { - var queryRes []Setting - res := make(map[string]string) - - DB.Where("type IN (?)", types).Find(&queryRes) - for _, setting := range queryRes { - res[setting.Name] = setting.Value - } - - return res -} - -// GetSiteURL 获取站点地址 -func GetSiteURL() *url.URL { - base, err := url.Parse(GetSettingByName("siteURL")) - if err != nil { - base, _ = url.Parse("https://cloudreve.org") - } - return base -} - -// GetIntSetting 获取整形设置值,如果转换失败则返回默认值defaultVal -func GetIntSetting(key string, defaultVal int) int { - res, err := 
strconv.Atoi(GetSettingByName(key)) - if err != nil { - return defaultVal - } - return res -} diff --git a/models/setting_test.go b/models/setting_test.go deleted file mode 100644 index 96fc5e03..00000000 --- a/models/setting_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package model - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - DB, _ = gorm.Open("mysql", db) - mockDB = DB - defer db.Close() - m.Run() -} - -func TestGetSettingByType(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByType([]string{"basic"}) - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - - rows = sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic2") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByType([]string{"basic", "basic2"}) - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - - //找不到 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByType([]string{"basic233"}) - asserts.Equal(map[string]string{}, settings) -} - -func TestGetSettingByNameWithDefault(t *testing.T) { - a := assert.New(t) - - rows := sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByNameWithDefault("123", "123321") - a.Equal("123321", settings) -} - -func TestGetSettingByNames(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByNames("siteName", "siteDes") - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - //找到其中一个设置时 - rows = sqlmock.NewRows([]string{"name", "value", "type"}). 
- AddRow("siteName2", "Cloudreve", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByNames("siteName2", "siteDes2333") - asserts.Equal(map[string]string{ - "siteName2": "Cloudreve", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - //找不到设置时 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByNames("siteName2333", "siteDes2333") - asserts.Equal(map[string]string{}, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - // 一个设置命中缓存 - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WithArgs("siteDes2").WillReturnRows(sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteDes2", "Cloudreve2", "basic")) - settings = GetSettingByNames("siteName", "siteDes2") - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes2": "Cloudreve2", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -// TestGetSettingByName 测试GetSettingByName -func TestGetSettingByName(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - - siteName := GetSettingByName("siteName") - asserts.Equal("Cloudreve", siteName) - asserts.NoError(mock.ExpectationsWereMet()) - - // 第二次查询应返回缓存内容 - siteNameCache := GetSettingByName("siteName") - asserts.Equal("Cloudreve", siteNameCache) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找不到设置 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - - siteName = GetSettingByName("siteName not exist") - asserts.Equal("", siteName) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestIsTrueVal(t *testing.T) { - asserts := assert.New(t) - - asserts.True(IsTrueVal("1")) - asserts.True(IsTrueVal("true")) - asserts.False(IsTrueVal("0")) - asserts.False(IsTrueVal("false")) -} - -func TestGetSiteURL(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - - mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://drive.cloudreve.org")) - siteURL := GetSiteURL() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("https://drive.cloudreve.org", siteURL.String()) - } - - // 失败 返回默认值 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - - mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, ":][\\/\\]sdf")) - siteURL := GetSiteURL() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("https://cloudreve.org", siteURL.String()) - } -} - -func TestGetIntSetting(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - cache.Set("setting_TestGetIntSetting", "10", 0) - res := GetIntSetting("TestGetIntSetting", 20) - asserts.Equal(10, res) - } - - // 使用默认值 - { - res := GetIntSetting("TestGetIntSetting_2", 20) - asserts.Equal(20, res) - } - -} diff --git a/models/share.go b/models/share.go deleted file mode 100644 index 
750eb48e..00000000 --- a/models/share.go +++ /dev/null @@ -1,246 +0,0 @@ -package model - -import ( - "errors" - "fmt" - "strings" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" -) - -// Share 分享模型 -type Share struct { - gorm.Model - Password string // 分享密码,空值为非加密分享 - IsDir bool // 原始资源是否为目录 - UserID uint // 创建用户ID - SourceID uint // 原始资源ID - Views int // 浏览数 - Downloads int // 下载数 - RemainDownloads int // 剩余下载配额,负值标识无限制 - Expires *time.Time // 过期时间,空值表示无过期时间 - PreviewEnabled bool // 是否允许直接预览 - SourceName string `gorm:"index:source"` // 用于搜索的字段 - - // 数据库忽略字段 - User User `gorm:"PRELOAD:false,association_autoupdate:false"` - File File `gorm:"PRELOAD:false,association_autoupdate:false"` - Folder Folder `gorm:"PRELOAD:false,association_autoupdate:false"` -} - -// Create 创建分享 -func (share *Share) Create() (uint, error) { - if err := DB.Create(share).Error; err != nil { - util.Log().Warning("Failed to insert share record: %s", err) - return 0, err - } - return share.ID, nil -} - -// GetShareByHashID 根据HashID查找分享 -func GetShareByHashID(hashID string) *Share { - id, err := hashid.DecodeHashID(hashID, hashid.ShareID) - if err != nil { - return nil - } - var share Share - result := DB.First(&share, id) - if result.Error != nil { - return nil - } - - return &share -} - -// IsAvailable 返回此分享是否可用(是否过期) -func (share *Share) IsAvailable() bool { - if share.RemainDownloads == 0 { - return false - } - if share.Expires != nil && time.Now().After(*share.Expires) { - return false - } - - // 检查创建者状态 - if share.Creator().Status != Active { - return false - } - - // 检查源对象是否存在 - var sourceID uint - if share.IsDir { - folder := share.SourceFolder() - sourceID = folder.ID - } else { - file := share.SourceFile() - sourceID = file.ID - } - if sourceID == 0 { - // TODO 是否要在这里删除这个无效分享? 
- return false - } - - return true -} - -// Creator 获取分享的创建者 -func (share *Share) Creator() *User { - if share.User.ID == 0 { - share.User, _ = GetUserByID(share.UserID) - } - return &share.User -} - -// Source 返回源对象 -func (share *Share) Source() interface{} { - if share.IsDir { - return share.SourceFolder() - } - return share.SourceFile() -} - -// SourceFolder 获取源目录 -func (share *Share) SourceFolder() *Folder { - if share.Folder.ID == 0 { - folders, _ := GetFoldersByIDs([]uint{share.SourceID}, share.UserID) - if len(folders) > 0 { - share.Folder = folders[0] - } - } - return &share.Folder -} - -// SourceFile 获取源文件 -func (share *Share) SourceFile() *File { - if share.File.ID == 0 { - files, _ := GetFilesByIDs([]uint{share.SourceID}, share.UserID) - if len(files) > 0 { - share.File = files[0] - } - } - return &share.File -} - -// CanBeDownloadBy 返回此分享是否可以被给定用户下载 -func (share *Share) CanBeDownloadBy(user *User) error { - // 用户组权限 - if !user.Group.OptionsSerialized.ShareDownload { - if user.IsAnonymous() { - return errors.New("you must login to download") - } - return errors.New("your group has no permission to download") - } - return nil -} - -// WasDownloadedBy 返回分享是否已被用户下载过 -func (share *Share) WasDownloadedBy(user *User, c *gin.Context) (exist bool) { - if user.IsAnonymous() { - exist = util.GetSession(c, fmt.Sprintf("share_%d_%d", share.ID, user.ID)) != nil - } else { - _, exist = cache.Get(fmt.Sprintf("share_%d_%d", share.ID, user.ID)) - } - - return exist -} - -// DownloadBy 增加下载次数,匿名用户不会缓存 -func (share *Share) DownloadBy(user *User, c *gin.Context) error { - if !share.WasDownloadedBy(user, c) { - share.Downloaded() - if !user.IsAnonymous() { - cache.Set(fmt.Sprintf("share_%d_%d", share.ID, user.ID), true, - GetIntSetting("share_download_session_timeout", 2073600)) - } else { - util.SetSession(c, map[string]interface{}{fmt.Sprintf("share_%d_%d", share.ID, user.ID): true}) - } - } - return nil -} - -// Viewed 增加访问次数 -func (share *Share) Viewed() { - share.Views++ - DB.Model(share).UpdateColumn("views", gorm.Expr("views + ?", 1)) -} - -// Downloaded 增加下载次数 -func (share *Share) Downloaded() { - share.Downloads++ - if share.RemainDownloads > 0 { - share.RemainDownloads-- - } - DB.Model(share).Updates(map[string]interface{}{ - "downloads": share.Downloads, - "remain_downloads": share.RemainDownloads, - }) -} - -// Update 更新分享属性 -func (share *Share) Update(props map[string]interface{}) error { - return DB.Model(share).Updates(props).Error -} - -// Delete 删除分享 -func (share *Share) Delete() error { - return DB.Model(share).Delete(share).Error -} - -// DeleteShareBySourceIDs 根据原始资源类型和ID删除文件 -func DeleteShareBySourceIDs(sources []uint, isDir bool) error { - return DB.Where("source_id in (?) 
and is_dir = ?", sources, isDir).Delete(&Share{}).Error -} - -// ListShares 列出UID下的分享 -func ListShares(uid uint, page, pageSize int, order string, publicOnly bool) ([]Share, int) { - var ( - shares []Share - total int - ) - dbChain := DB - dbChain = dbChain.Where("user_id = ?", uid) - if publicOnly { - dbChain = dbChain.Where("password = ?", "") - } - - // 计算总数用于分页 - dbChain.Model(&Share{}).Count(&total) - - // 查询记录 - dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&shares) - return shares, total -} - -// SearchShares 根据关键字搜索分享 -func SearchShares(page, pageSize int, order, keywords string) ([]Share, int) { - var ( - shares []Share - total int - ) - - keywordList := strings.Split(keywords, " ") - availableList := make([]string, 0, len(keywordList)) - for i := 0; i < len(keywordList); i++ { - if len(keywordList[i]) > 0 { - availableList = append(availableList, keywordList[i]) - } - } - if len(availableList) == 0 { - return shares, 0 - } - - dbChain := DB - dbChain = dbChain.Where("password = ? and remain_downloads <> 0 and (expires is NULL or expires > ?) and source_name like ?", "", time.Now(), "%"+strings.Join(availableList, "%")+"%") - - // 计算总数用于分页 - dbChain.Model(&Share{}).Count(&total) - - // 查询记录 - dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&shares) - return shares, total -} diff --git a/models/share_test.go b/models/share_test.go deleted file mode 100644 index b3fdf0a1..00000000 --- a/models/share_test.go +++ /dev/null @@ -1,321 +0,0 @@ -package model - -import ( - "errors" - "net/http/httptest" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestShare_Create(t *testing.T) { - asserts := assert.New(t) - share := Share{UserID: 1} - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - id, err := share.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(2, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - id, err := share.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestGetShareByHashID(t *testing.T) { - asserts := assert.New(t) - conf.SystemConfig.HashIDSalt = "" - - // 成功 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res := GetShareByHashID("x9T4") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(res) - } - - // 查询失败 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - res := GetShareByHashID("x9T4") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(res) - } - - // ID解码失败 - { - res := GetShareByHashID("empty") - asserts.Nil(res) - } - -} - -func TestShare_IsAvailable(t *testing.T) { - asserts := assert.New(t) - - // 下载剩余次数为0 - { - share := Share{} - asserts.False(share.IsAvailable()) - } - - // 时效过期 - { - expires := time.Unix(10, 10) - share := Share{ - RemainDownloads: -1, - Expires: &expires, - } - asserts.False(share.IsAvailable()) - } - - // 源对象为目录,但不存在 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: true, - } - mock.ExpectQuery("SELECT(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id"})) - asserts.False(share.IsAvailable()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 源对象为目录,存在 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: false, - } - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(13)) - asserts.True(share.IsAvailable()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 用户被封禁 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: true, - User: User{Status: Baned}, - } - asserts.False(share.IsAvailable()) - } -} - -func TestShare_GetCreator(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - share := Share{UserID: 1} - res := share.Creator() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(1, res.ID) -} - -func TestShare_Source(t *testing.T) { - asserts := assert.New(t) - - // 目录 - { - share := Share{IsDir: true, SourceID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - asserts.EqualValues(3, share.Source().(*Folder).ID) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 文件 - { - share := Share{IsDir: false, SourceID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - asserts.EqualValues(3, share.Source().(*File).ID) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestShare_CanBeDownloadBy(t *testing.T) { - asserts := assert.New(t) - share := Share{} - - // 未登录,无权 - { - user := &User{ - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: false, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - - // 已登录,无权 - { - user := &User{ - Model: gorm.Model{ID: 1}, - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: false, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - - // 成功 - { - user := &User{ - Model: gorm.Model{ID: 1}, - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: true, - }, - }, - } - asserts.NoError(share.CanBeDownloadBy(user)) - } -} - -func TestShare_WasDownloadedBy(t *testing.T) { - asserts := assert.New(t) - share := Share{ - Model: gorm.Model{ID: 1}, - } - - // 已登录,已下载 - { - user := User{ - Model: gorm.Model{ - ID: 1, - }, - } - r := httptest.NewRecorder() - c, _ := gin.CreateTestContext(r) - cache.Set("share_1_1", true, 0) - asserts.True(share.WasDownloadedBy(&user, c)) - } -} - -func TestShare_DownloadBy(t *testing.T) { - asserts := assert.New(t) - share := Share{ - Model: gorm.Model{ID: 1}, - } - user := User{ - Model: gorm.Model{ - ID: 1, - }, - } - cache.Deletes([]string{"1_1"}, "share_") - r := httptest.NewRecorder() - c, _ := gin.CreateTestContext(r) - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := share.DownloadBy(&user, c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - _, ok := cache.Get("share_1_1") - asserts.True(ok) -} - -func TestShare_Viewed(t *testing.T) { - asserts := assert.New(t) - share := Share{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). 
- WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - share.Viewed() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(1, share.Views) -} - -func TestShare_UpdateAndDelete(t *testing.T) { - asserts := assert.New(t) - share := Share{} - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Update(map[string]interface{}{"id": 1}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Delete() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteShareBySourceIDs([]uint{1}, true) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - -} - -func TestListShares(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(2)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2)) - - res, total := ListShares(1, 1, 10, "desc", true) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - asserts.Equal(2, total) -} - -func TestSearchShares(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)"). - WithArgs("", sqlmock.AnyArg(), "%1%2%"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, total := SearchShares(1, 10, "id", "1 2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 1) - asserts.Equal(1, total) -} diff --git a/models/source_link.go b/models/source_link.go deleted file mode 100644 index 49dfea28..00000000 --- a/models/source_link.go +++ /dev/null @@ -1,47 +0,0 @@ -package model - -import ( - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/jinzhu/gorm" - "net/url" -) - -// SourceLink represent a shared file source link -type SourceLink struct { - gorm.Model - FileID uint // corresponding file ID - Name string // name of the file while creating the source link, for annotation - Downloads int // 下载数 - - // 关联模型 - File File `gorm:"save_associations:false:false"` -} - -// Link gets the URL of a SourceLink -func (s *SourceLink) Link() (string, error) { - baseURL := GetSiteURL() - linkPath, err := url.Parse(fmt.Sprintf("/f/%s/%s", hashid.HashID(s.ID, hashid.SourceLinkID), s.File.Name)) - if err != nil { - return "", err - } - return baseURL.ResolveReference(linkPath).String(), nil -} - -// GetTasksByID queries source link based on ID -func GetSourceLinkByID(id interface{}) (*SourceLink, error) { - link := &SourceLink{} - result := DB.Where("id = ?", id).First(link) - files, _ := GetFilesByIDs([]uint{link.FileID}, 0) - if len(files) > 0 { - link.File = files[0] - } - - return link, result.Error -} - -// Viewed 增加访问次数 -func (s *SourceLink) Downloaded() { - s.Downloads++ - DB.Model(s).UpdateColumn("downloads", gorm.Expr("downloads + ?", 1)) -} diff --git a/models/source_link_test.go b/models/source_link_test.go deleted file mode 100644 index d84dc628..00000000 --- a/models/source_link_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func 
TestSourceLink_Link(t *testing.T) { - a := assert.New(t) - s := &SourceLink{} - s.ID = 1 - - // 失败 - { - s.File.Name = string([]byte{0x7f}) - res, err := s.Link() - a.Error(err) - a.Empty(res) - } - - // 成功 - { - s.File.Name = "filename" - res, err := s.Link() - a.NoError(err) - a.Contains(res, s.Name) - } -} - -func TestGetSourceLinkByID(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - - res, err := GetSourceLinkByID(1) - a.NoError(err) - a.NotNil(res) - a.EqualValues(2, res.File.ID) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestSourceLink_Downloaded(t *testing.T) { - a := assert.New(t) - s := &SourceLink{} - s.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - s.Downloaded() - a.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/tag.go b/models/tag.go deleted file mode 100644 index 5ce1a4db..00000000 --- a/models/tag.go +++ /dev/null @@ -1,53 +0,0 @@ -package model - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// Tag 用户自定义标签 -type Tag struct { - gorm.Model - Name string // 标签名 - Icon string // 图标标识 - Color string // 图标颜色 - Type int // 标签类型(文件分类/目录直达) - Expression string `gorm:"type:text"` // 搜索表表达式/直达路径 - UserID uint // 创建者ID -} - -const ( - // FileTagType 文件分类标签 - FileTagType = iota - // DirectoryLinkType 目录快捷方式标签 - DirectoryLinkType -) - -// Create 创建标签记录 -func (tag *Tag) Create() (uint, error) { - if err := DB.Create(tag).Error; err != nil { - util.Log().Warning("Failed to insert tag record: %s", err) - return 0, err - } - return tag.ID, nil -} - -// DeleteTagByID 根据给定ID和用户ID删除标签 -func DeleteTagByID(id, uid uint) error { - result := DB.Where("id = ? and user_id = ?", id, uid).Delete(&Tag{}) - return result.Error -} - -// GetTagsByUID 根据用户ID查找标签 -func GetTagsByUID(uid uint) ([]Tag, error) { - var tag []Tag - result := DB.Where("user_id = ?", uid).Find(&tag) - return tag, result.Error -} - -// GetTagsByID 根据ID查找标签 -func GetTagsByID(id, uid uint) (*Tag, error) { - var tag Tag - result := DB.Where("user_id = ? 
and id = ?", uid, id).First(&tag) - return &tag, result.Error -} diff --git a/models/tag_test.go b/models/tag_test.go deleted file mode 100644 index be8d3fb5..00000000 --- a/models/tag_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTag_Create(t *testing.T) { - asserts := assert.New(t) - tag := Tag{} - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - id, err := tag.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - id, err := tag.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestDeleteTagByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteTagByID(1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) -} - -func TestGetTagsByUID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetTagsByUID(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) -} - -func TestGetTagsByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag")) - res, err := GetTagsByID(1, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues("tag", res.Name) -} diff --git a/models/task.go b/models/task.go deleted file mode 100644 index a6fde2ef..00000000 --- a/models/task.go +++ /dev/null @@ -1,73 +0,0 @@ -package model - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" -) - -// Task 任务模型 -type Task struct { - gorm.Model - Status int // 任务状态 - Type int // 任务类型 - UserID uint // 发起者UID,0表示为系统发起 - Progress int // 进度 - Error string `gorm:"type:text"` // 错误信息 - Props string `gorm:"type:text"` // 任务属性 -} - -// Create 创建任务记录 -func (task *Task) Create() (uint, error) { - if err := DB.Create(task).Error; err != nil { - util.Log().Warning("Failed to insert task record: %s", err) - return 0, err - } - return task.ID, nil -} - -// SetStatus 设定任务状态 -func (task *Task) SetStatus(status int) error { - return DB.Model(task).Select("status").Updates(map[string]interface{}{"status": status}).Error -} - -// SetProgress 设定任务进度 -func (task *Task) SetProgress(progress int) error { - return DB.Model(task).Select("progress").Updates(map[string]interface{}{"progress": progress}).Error -} - -// SetError 设定错误信息 -func (task *Task) SetError(err string) error { - return DB.Model(task).Select("error").Updates(map[string]interface{}{"error": err}).Error -} - -// GetTasksByStatus 根据状态检索任务 -func GetTasksByStatus(status ...int) []Task { - var tasks []Task - DB.Where("status in (?)", status).Find(&tasks) - return tasks -} - -// GetTasksByID 根据ID检索任务 -func GetTasksByID(id interface{}) (*Task, error) { - task := &Task{} - result := DB.Where("id = ?", id).First(task) - return task, result.Error -} - -// ListTasks 列出用户所属的任务 -func ListTasks(uid uint, page, pageSize int, order string) ([]Task, int) { - var ( - tasks []Task - total int - ) - dbChain := DB - dbChain 
= dbChain.Where("user_id = ?", uid) - - // 计算总数用于分页 - dbChain.Model(&Task{}).Count(&total) - - // 查询记录 - dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks) - - return tasks, total -} diff --git a/models/task_test.go b/models/task_test.go deleted file mode 100644 index 1ad71c3d..00000000 --- a/models/task_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTask_Create(t *testing.T) { - asserts := assert.New(t) - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task := Task{Props: "1"} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - task := Task{Props: "1"} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetError("error")) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetStatus(1)) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTask_SetProgress(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetProgress(1)) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetTasksByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetTasksByID(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, res.ID) -} - -func TestListTasks(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5)) - - res, total := ListTasks(1, 1, 10, "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(5, total) - asserts.Len(res, 1) -} - -func TestGetTasksByStatus(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2). 
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res := GetTasksByStatus(1, 2) - a.NoError(mock.ExpectationsWereMet()) - a.Len(res, 1) -} diff --git a/models/user.go b/models/user.go deleted file mode 100644 index ff1d6dd6..00000000 --- a/models/user.go +++ /dev/null @@ -1,290 +0,0 @@ -package model - -import ( - "crypto/md5" - "crypto/sha1" - "encoding/gob" - "encoding/hex" - "encoding/json" - "strings" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/pkg/errors" -) - -const ( - // Active 账户正常状态 - Active = iota - // NotActivicated 未激活 - NotActivicated - // Baned 被封禁 - Baned - // OveruseBaned 超额使用被封禁 - OveruseBaned -) - -// User 用户模型 -type User struct { - // 表字段 - gorm.Model - Email string `gorm:"type:varchar(100);unique_index"` - Nick string `gorm:"size:50"` - Password string `json:"-"` - Status int - GroupID uint - Storage uint64 - TwoFactor string - Avatar string - Options string `json:"-" gorm:"size:4294967295"` - Authn string `gorm:"size:4294967295"` - - // 关联模型 - Group Group `gorm:"save_associations:false:false"` - Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` - - // 数据库忽略字段 - OptionsSerialized UserOption `gorm:"-"` -} - -func init() { - gob.Register(User{}) -} - -// UserOption 用户个性化配置字段 -type UserOption struct { - ProfileOff bool `json:"profile_off,omitempty"` - PreferredTheme string `json:"preferred_theme,omitempty"` -} - -// Root 获取用户的根目录 -func (user *User) Root() (*Folder, error) { - var folder Folder - err := DB.Where("parent_id is NULL AND owner_id = ?", user.ID).First(&folder).Error - return &folder, err -} - -// DeductionStorage 减少用户已用容量 -func (user *User) DeductionStorage(size uint64) bool { - if size == 0 { - return true - } - if size <= user.Storage { - user.Storage -= size - DB.Model(user).Update("storage", gorm.Expr("storage - ?", size)) - return true - } - // 如果要减少的容量超出已用容量,则设为零 - user.Storage = 0 - DB.Model(user).Update("storage", 0) - - return false -} - -// IncreaseStorage 检查并增加用户已用容量 -func (user *User) IncreaseStorage(size uint64) bool { - if size == 0 { - return true - } - if size <= user.GetRemainingCapacity() { - user.Storage += size - DB.Model(user).Update("storage", gorm.Expr("storage + ?", size)) - return true - } - return false -} - -// ChangeStorage 更新用户容量 -func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error { - return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error -} - -// IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量 -func (user *User) IncreaseStorageWithoutCheck(size uint64) { - if size == 0 { - return - } - user.Storage += size - DB.Model(user).Update("storage", gorm.Expr("storage + ?", size)) - -} - -// GetRemainingCapacity 获取剩余配额 -func (user *User) GetRemainingCapacity() uint64 { - total := user.Group.MaxStorage - if total <= user.Storage { - return 0 - } - return total - user.Storage -} - -// GetPolicyID 获取用户当前的存储策略ID -func (user *User) GetPolicyID(prefer uint) uint { - if len(user.Group.PolicyList) > 0 { - return user.Group.PolicyList[0] - } - return 0 -} - -// GetUserByID 用ID获取用户 -func GetUserByID(ID interface{}) (User, error) { - var user User - result := DB.Set("gorm:auto_preload", true).First(&user, ID) - return user, result.Error -} - -// GetActiveUserByID 用ID获取可登录用户 -func GetActiveUserByID(ID interface{}) (User, error) { - var user User - result := DB.Set("gorm:auto_preload", true).Where("status = ?", Active).First(&user, ID) - return user, result.Error -} - -// GetActiveUserByOpenID 用OpenID获取可登录用户 -func 
GetActiveUserByOpenID(openid string) (User, error) { - var user User - result := DB.Set("gorm:auto_preload", true).Where("status = ? and open_id = ?", Active, openid).Find(&user) - return user, result.Error -} - -// GetUserByEmail 用Email获取用户 -func GetUserByEmail(email string) (User, error) { - var user User - result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user) - return user, result.Error -} - -// GetActiveUserByEmail 用Email获取可登录用户 -func GetActiveUserByEmail(email string) (User, error) { - var user User - result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user) - return user, result.Error -} - -// NewUser 返回一个新的空 User -func NewUser() User { - options := UserOption{} - return User{ - OptionsSerialized: options, - } -} - -// BeforeSave Save用户前的钩子 -func (user *User) BeforeSave() (err error) { - err = user.SerializeOptions() - return err -} - -// AfterCreate 创建用户后的钩子 -func (user *User) AfterCreate(tx *gorm.DB) (err error) { - // 创建用户的默认根目录 - defaultFolder := &Folder{ - Name: "/", - OwnerID: user.ID, - } - tx.Create(defaultFolder) - return err -} - -// AfterFind 找到用户后的钩子 -func (user *User) AfterFind() (err error) { - // 解析用户设置到OptionsSerialized - if user.Options != "" { - err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) - } - - // 预加载存储策略 - user.Policy, _ = GetPolicyByID(user.GetPolicyID(0)) - return err -} - -//SerializeOptions 将序列后的Option写入到数据库字段 -func (user *User) SerializeOptions() (err error) { - optionsValue, err := json.Marshal(&user.OptionsSerialized) - user.Options = string(optionsValue) - return err -} - -// CheckPassword 根据明文校验密码 -func (user *User) CheckPassword(password string) (bool, error) { - - // 根据存储密码拆分为 Salt 和 Digest - passwordStore := strings.Split(user.Password, ":") - if len(passwordStore) != 2 && len(passwordStore) != 3 { - return false, errors.New("Unknown password type") - } - - // 兼容V2密码,升级后存储格式为: md5:$HASH:$SALT - if len(passwordStore) == 3 { - if passwordStore[0] != "md5" { - return false, errors.New("Unknown password type") - } - hash := md5.New() - _, err := hash.Write([]byte(passwordStore[2] + password)) - bs := hex.EncodeToString(hash.Sum(nil)) - if err != nil { - return false, err - } - return bs == passwordStore[1], nil - } - - //计算 Salt 和密码组合的SHA1摘要 - hash := sha1.New() - _, err := hash.Write([]byte(password + passwordStore[0])) - bs := hex.EncodeToString(hash.Sum(nil)) - if err != nil { - return false, err - } - - return bs == passwordStore[1], nil -} - -// SetPassword 根据给定明文设定 User 的 Password 字段 -func (user *User) SetPassword(password string) error { - //生成16位 Salt - salt := util.RandStringRunes(16) - - //计算 Salt 和密码组合的SHA1摘要 - hash := sha1.New() - _, err := hash.Write([]byte(password + salt)) - bs := hex.EncodeToString(hash.Sum(nil)) - - if err != nil { - return err - } - - //存储 Salt 值和摘要, ":"分割 - user.Password = salt + ":" + string(bs) - return nil -} - -// NewAnonymousUser 返回一个匿名用户 -func NewAnonymousUser() *User { - user := User{} - user.Policy.Type = "anonymous" - user.Group, _ = GetGroupByID(3) - return &user -} - -// IsAnonymous 返回是否为未登录用户 -func (user *User) IsAnonymous() bool { - return user.ID == 0 -} - -// SetStatus 设定用户状态 -func (user *User) SetStatus(status int) { - DB.Model(&user).Update("status", status) -} - -// Update 更新用户 -func (user *User) Update(val map[string]interface{}) error { - return DB.Model(user).Updates(val).Error -} - -// UpdateOptions 更新用户偏好设定 -func (user *User) UpdateOptions() error { - if err := user.SerializeOptions(); err != nil { 
- return err - } - return user.Update(map[string]interface{}{"options": user.Options}) -} diff --git a/models/user_authn.go b/models/user_authn.go deleted file mode 100644 index ba329bf1..00000000 --- a/models/user_authn.go +++ /dev/null @@ -1,79 +0,0 @@ -package model - -import ( - "encoding/base64" - "encoding/binary" - "encoding/json" - "fmt" - "net/url" - - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/duo-labs/webauthn/webauthn" -) - -/* - `webauthn.User` 接口的实现 -*/ - -// WebAuthnID 返回用户ID -func (user User) WebAuthnID() []byte { - bs := make([]byte, 8) - binary.LittleEndian.PutUint64(bs, uint64(user.ID)) - return bs -} - -// WebAuthnName 返回用户名 -func (user User) WebAuthnName() string { - return user.Email -} - -// WebAuthnDisplayName 获得用于展示的用户名 -func (user User) WebAuthnDisplayName() string { - return user.Nick -} - -// WebAuthnIcon 获得用户头像 -func (user User) WebAuthnIcon() string { - avatar, _ := url.Parse("/api/v3/user/avatar/" + hashid.HashID(user.ID, hashid.UserID) + "/l") - base := GetSiteURL() - base.Scheme = "https" - return base.ResolveReference(avatar).String() -} - -// WebAuthnCredentials 获得已注册的验证器凭证 -func (user User) WebAuthnCredentials() []webauthn.Credential { - var res []webauthn.Credential - err := json.Unmarshal([]byte(user.Authn), &res) - if err != nil { - fmt.Println(err) - } - return res -} - -// RegisterAuthn 添加新的验证器 -func (user *User) RegisterAuthn(credential *webauthn.Credential) error { - exists := user.WebAuthnCredentials() - exists = append(exists, *credential) - res, err := json.Marshal(exists) - if err != nil { - return err - } - - return DB.Model(user).Update("authn", string(res)).Error -} - -// RemoveAuthn 删除验证器 -func (user *User) RemoveAuthn(id string) { - exists := user.WebAuthnCredentials() - for i := 0; i < len(exists); i++ { - idEncoded := base64.StdEncoding.EncodeToString(exists[i].ID) - if idEncoded == id { - exists[len(exists)-1], exists[i] = exists[i], exists[len(exists)-1] - exists = exists[:len(exists)-1] - break - } - } - - res, _ := json.Marshal(exists) - DB.Model(user).Update("authn", string(res)) -} diff --git a/models/user_authn_test.go b/models/user_authn_test.go deleted file mode 100644 index 08a8ce19..00000000 --- a/models/user_authn_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/duo-labs/webauthn/webauthn" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestUser_RegisterAuthn(t *testing.T) { - asserts := assert.New(t) - credential := webauthn.Credential{} - user := User{ - Model: gorm.Model{ID: 1}, - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). 
- WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.RegisterAuthn(&credential) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestUser_WebAuthnCredentials(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, - } - { - credentials := user.WebAuthnCredentials() - asserts.Len(credentials, 1) - } -} - -func TestUser_WebAuthnDisplayName(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Nick: "123", - } - { - nick := user.WebAuthnDisplayName() - asserts.Equal("123", nick) - } -} - -func TestUser_WebAuthnIcon(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - } - { - icon := user.WebAuthnIcon() - asserts.NotEmpty(icon) - } -} - -func TestUser_WebAuthnID(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - } - { - id := user.WebAuthnID() - asserts.Len(id, 8) - } -} - -func TestUser_WebAuthnName(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Email: "abslant@foxmail.com", - } - { - name := user.WebAuthnName() - asserts.Equal("abslant@foxmail.com", name) - } -} - -func TestUser_RemoveAuthn(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, - } - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.RemoveAuthn("123") - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/user_test.go b/models/user_test.go deleted file mode 100644 index a85ddbd3..00000000 --- a/models/user_test.go +++ /dev/null @@ -1,438 +0,0 @@ -package model - -import ( - "encoding/json" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestGetUserByID(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") - //找到用户时 - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). - AddRow(1, nil, "admin@cloudreve.org", "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). 
- AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - user, err := GetUserByID(1) - asserts.NoError(err) - asserts.Equal(User{ - Model: gorm.Model{ - ID: 1, - DeletedAt: nil, - }, - Email: "admin@cloudreve.org", - Options: "{}", - GroupID: 1, - Group: Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, - Policy: Policy{ - Model: gorm.Model{ - ID: 1, - }, - OptionsSerialized: PolicyOption{ - FileType: []string{}, - }, - Name: "默认存储策略", - }, - }, user) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - user, err = GetUserByID(1) - asserts.Error(err) - asserts.Equal(User{}, user) -} - -func TestGetActiveUserByID(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") - //找到用户时 - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). - AddRow(1, nil, "admin@cloudreve.org", "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - user, err := GetActiveUserByID(1) - asserts.NoError(err) - asserts.Equal(User{ - Model: gorm.Model{ - ID: 1, - DeletedAt: nil, - }, - Email: "admin@cloudreve.org", - Options: "{}", - GroupID: 1, - Group: Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, - Policy: Policy{ - Model: gorm.Model{ - ID: 1, - }, - OptionsSerialized: PolicyOption{ - FileType: []string{}, - }, - Name: "默认存储策略", - }, - }, user) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - user, err = GetActiveUserByID(1) - asserts.Error(err) - asserts.Equal(User{}, user) -} - -func TestUser_SetPassword(t *testing.T) { - asserts := assert.New(t) - user := User{} - err := user.SetPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.NotEmpty(user.Password) -} - -func TestUser_CheckPassword(t *testing.T) { - asserts := assert.New(t) - user := User{} - err := user.SetPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - - //密码正确 - res, err := user.CheckPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.True(res) - - //密码错误 - res, err = user.CheckPassword("Cause Sega does what Nintendon't") - asserts.NoError(err) - asserts.False(res) - - //密码字段为空 - user = User{} - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.Error(err) - asserts.False(res) - - // 未知密码类型 - user = User{} - user.Password = "1:2:3" - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.Error(err) - asserts.False(res) - - // V2密码,错误 - user = User{} - user.Password = "md5:2:3" - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.False(res) - - // V2密码,正确 - user = User{} - user.Password = "md5:d8446059f8846a2c111a7f53515665fb:sdshare" - res, err = user.CheckPassword("admin") - asserts.NoError(err) - asserts.True(res) - -} - -func TestNewUser(t *testing.T) { - asserts := assert.New(t) - newUser := NewUser() - asserts.IsType(User{}, newUser) - asserts.Empty(newUser.Avatar) -} - -func TestUser_AfterFind(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"0"}, "policy_") - - 
policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(144, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - newUser := NewUser() - err := newUser.AfterFind() - err = newUser.BeforeSave() - expected := UserOption{} - err = json.Unmarshal([]byte(newUser.Options), &expected) - - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(expected, newUser.OptionsSerialized) - asserts.Equal("默认存储策略", newUser.Policy.Name) -} - -func TestUser_BeforeSave(t *testing.T) { - asserts := assert.New(t) - - newUser := NewUser() - err := newUser.BeforeSave() - expected, err := json.Marshal(newUser.OptionsSerialized) - - asserts.NoError(err) - asserts.Equal(string(expected), newUser.Options) -} - -func TestUser_GetPolicyID(t *testing.T) { - asserts := assert.New(t) - - newUser := NewUser() - newUser.Group.PolicyList = []uint{1} - - asserts.EqualValues(1, newUser.GetPolicyID(0)) - - newUser.Group.PolicyList = nil - asserts.EqualValues(0, newUser.GetPolicyID(0)) - - newUser.Group.PolicyList = []uint{} - asserts.EqualValues(0, newUser.GetPolicyID(0)) -} - -func TestUser_GetRemainingCapacity(t *testing.T) { - asserts := assert.New(t) - newUser := NewUser() - cache.Set("pack_size_0", uint64(0), 0) - - newUser.Group.MaxStorage = 100 - asserts.Equal(uint64(100), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 1 - asserts.Equal(uint64(99), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 100 - asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 200 - asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) -} - -func TestUser_DeductionCapacity(t *testing.T) { - asserts := assert.New(t) - - cache.Deletes([]string{"1"}, "policy_") - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}). - AddRow(1, nil, 0, "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). 
- AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - newUser, err := GetUserByID(1) - newUser.Group.MaxStorage = 100 - cache.Set("pack_size_1", uint64(0), 0) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - asserts.Equal(false, newUser.IncreaseStorage(101)) - asserts.Equal(uint64(0), newUser.Storage) - - asserts.Equal(true, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(1), newUser.Storage) - - asserts.Equal(true, newUser.IncreaseStorage(99)) - asserts.Equal(uint64(100), newUser.Storage) - - asserts.Equal(false, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(100), newUser.Storage) - - asserts.True(newUser.IncreaseStorage(0)) -} - -func TestUser_DeductionStorage(t *testing.T) { - asserts := assert.New(t) - - // 减少零 - { - user := User{Storage: 1} - asserts.True(user.DeductionStorage(0)) - asserts.Equal(uint64(1), user.Storage) - } - // 正常 - { - user := User{ - Model: gorm.Model{ID: 1}, - Storage: 10, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(5, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.True(user.DeductionStorage(5)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(5), user.Storage) - } - - // 减少的超出可用的 - { - user := User{ - Model: gorm.Model{ID: 1}, - Storage: 10, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(0, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.False(user.DeductionStorage(20)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(0), user.Storage) - } -} - -func TestUser_IncreaseStorageWithoutCheck(t *testing.T) { - asserts := assert.New(t) - - // 增加零 - { - user := User{} - user.IncreaseStorageWithoutCheck(0) - asserts.Equal(uint64(0), user.Storage) - } - - // 减少零 - { - user := User{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(10, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - user.IncreaseStorageWithoutCheck(10) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(10), user.Storage) - } -} - -func TestGetActiveUserByEmail(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) - _, err := GetActiveUserByEmail("abslant@foxmail.com") - - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetUserByEmail(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) - _, err := GetUserByEmail("abslant@foxmail.com") - - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestUser_AfterCreate(t *testing.T) { - asserts := assert.New(t) - user := User{Model: gorm.Model{ID: 1}} - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := user.AfterCreate(DB) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestUser_Root(t *testing.T) { - asserts := assert.New(t) - user := User{Model: gorm.Model{ID: 1}} - - // 根目录存在 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "根目录")) - root, err := user.Root() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("根目录", root.Name) - } - - 
// 根目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - _, err := user.Root() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestNewAnonymousUser(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - user := NewAnonymousUser() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(user) - asserts.EqualValues(3, user.Group.ID) -} - -func TestUser_IsAnonymous(t *testing.T) { - asserts := assert.New(t) - user := User{} - asserts.True(user.IsAnonymous()) - user.ID = 1 - asserts.False(user.IsAnonymous()) -} - -func TestUser_SetStatus(t *testing.T) { - asserts := assert.New(t) - user := User{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.SetStatus(Baned) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(Baned, user.Status) -} - -func TestUser_UpdateOptions(t *testing.T) { - asserts := assert.New(t) - user := User{} - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.NoError(user.UpdateOptions()) - asserts.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/webdav.go b/models/webdav.go deleted file mode 100644 index 0799aee9..00000000 --- a/models/webdav.go +++ /dev/null @@ -1,48 +0,0 @@ -package model - -import ( - "github.com/jinzhu/gorm" -) - -// Webdav 应用账户 -type Webdav struct { - gorm.Model - Name string // 应用名称 - Password string `gorm:"unique_index:password_only_on"` // 应用密码 - UserID uint `gorm:"unique_index:password_only_on"` // 用户ID - Root string `gorm:"type:text"` // 根目录 - Readonly bool `gorm:"type:bool"` // 是否只读 - UseProxy bool `gorm:"type:bool"` // 是否进行反代 -} - -// Create 创建账户 -func (webdav *Webdav) Create() (uint, error) { - if err := DB.Create(webdav).Error; err != nil { - return 0, err - } - return webdav.ID, nil -} - -// GetWebdavByPassword 根据密码和用户查找Webdav应用 -func GetWebdavByPassword(password string, uid uint) (*Webdav, error) { - webdav := &Webdav{} - res := DB.Where("user_id = ? and password = ?", uid, password).First(webdav) - return webdav, res.Error -} - -// ListWebDAVAccounts 列出用户的所有账号 -func ListWebDAVAccounts(uid uint) []Webdav { - var accounts []Webdav - DB.Where("user_id = ?", uid).Order("created_at desc").Find(&accounts) - return accounts -} - -// DeleteWebDAVAccountByID 根据账户ID和UID删除账户 -func DeleteWebDAVAccountByID(id, uid uint) { - DB.Where("user_id = ? 
and id = ?", uid, id).Delete(&Webdav{}) -} - -// UpdateWebDAVAccountByID 根据账户ID和UID更新账户 -func UpdateWebDAVAccountByID(id, uid uint, updates map[string]interface{}) { - DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).Updates(updates) -} diff --git a/models/webdav_test.go b/models/webdav_test.go deleted file mode 100644 index 55a7326f..00000000 --- a/models/webdav_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestWebdav_Create(t *testing.T) { - asserts := assert.New(t) - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task := Webdav{} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - task := Webdav{} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestGetWebdavByPassword(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := GetWebdavByPassword("e", 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) -} - -func TestListWebDAVAccounts(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - res := ListWebDAVAccounts(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 0) -} - -func TestDeleteWebDAVAccountByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - DeleteWebDAVAccountByID(1, 1) - asserts.NoError(mock.ExpectationsWereMet()) -} diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go deleted file mode 100644 index f91766fa..00000000 --- a/pkg/aria2/aria2.go +++ /dev/null @@ -1,67 +0,0 @@ -package aria2 - -import ( - "context" - "fmt" - "net/url" - "sync" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" -) - -// Instance 默认使用的Aria2处理实例 -var Instance common.Aria2 = &common.DummyAria2{} - -// LB 获取 Aria2 节点的负载均衡器 -var LB balancer.Balancer - -// Lock Instance的读写锁 -var Lock sync.RWMutex - -// GetLoadBalancer 返回供Aria2使用的负载均衡器 -func GetLoadBalancer() balancer.Balancer { - Lock.RLock() - defer Lock.RUnlock() - return LB -} - -// Init 初始化 -func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) { - Lock.Lock() - LB = balancer.NewBalancer("RoundRobin") - Lock.Unlock() - - if !isReload { - // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding) - - for i := 0; i < len(unfinished); i++ { - // 创建任务监控 - monitor.NewMonitor(&unfinished[i], pool, mqClient) - } - } -} - -// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性 -func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) { - // 解析RPC服务地址 - rpcServer, err := url.Parse(server) - if err != nil { - return rpc.VersionInfo{}, fmt.Errorf("cannot 
parse RPC server: %w", err) - } - - rpcServer.Path = "/jsonrpc" - caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil) - if err != nil { - return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err) - } - - return caller.GetVersion() -} diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go deleted file mode 100644 index b6e7092d..00000000 --- a/pkg/aria2/aria2_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package aria2 - -import ( - "database/sql" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInit(t *testing.T) { - a := assert.New(t) - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", testMock.Anything).Return(nil) - mockQueue := mq.NewMQ() - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - Init(false, mockPool, mockQueue) - a.NoError(mock.ExpectationsWereMet()) - mockPool.AssertExpectations(t) -} - -func TestTestRPCConnection(t *testing.T) { - a := assert.New(t) - - // url not legal - { - res, err := TestRPCConnection(string([]byte{0x7f}), "", 10) - a.Error(err) - a.Empty(res.Version) - } - - // rpc failed - { - res, err := TestRPCConnection("ws://0.0.0.0", "", 0) - a.Error(err) - a.Empty(res.Version) - } -} - -func TestGetLoadBalancer(t *testing.T) { - a := assert.New(t) - a.NotPanics(func() { - GetLoadBalancer() - }) -} diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go deleted file mode 100644 index ae5e6b02..00000000 --- a/pkg/aria2/common/common.go +++ /dev/null @@ -1,119 +0,0 @@ -package common - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Aria2 离线下载处理接口 -type Aria2 interface { - // Init 初始化客户端连接 - Init() error - // CreateTask 创建新的任务 - CreateTask(task *model.Download, options map[string]interface{}) (string, error) - // 返回状态信息 - Status(task *model.Download) (rpc.StatusInfo, error) - // 取消任务 - Cancel(task *model.Download) error - // 选择要下载的文件 - Select(task *model.Download, files []int) error - // 获取离线下载配置 - GetConfig() model.Aria2Option - // 删除临时下载文件 - DeleteTempFile(*model.Download) error -} - -const ( - // URLTask 从URL添加的任务 - URLTask = iota - // TorrentTask 种子任务 - TorrentTask -) - -const ( - // Ready 准备就绪 - Ready = iota - // Downloading 下载中 - Downloading - // Paused 暂停中 - Paused - // Error 出错 - Error - // Complete 完成 - Complete - // Canceled 取消/停止 - Canceled - // Unknown 未知状态 - Unknown - // Seeding 做种中 - Seeding -) - -var ( - // ErrNotEnabled 功能未开启错误 - ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil) - // ErrUserNotFound 未找到下载任务创建者 - ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil) -) - -// DummyAria2 未开启Aria2功能时使用的默认处理器 -type DummyAria2 struct { -} - -func (instance *DummyAria2) Init() error { - return nil -} - -// CreateTask 创建新任务,此处直接返回未开启错误 
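// Aside: DummyAria2 above is a null-object default. While offline download is not configured,
// every call fails fast with ErrNotEnabled instead of forcing nil checks at call sites. A compact
// sketch of the same pattern under assumed names (Downloader, dummyDownloader are illustrative,
// not the v3 API):
package main

import (
	"errors"
	"fmt"
)

var ErrNotEnabled = errors.New("feature not enabled")

// Downloader is the subset of behaviour callers depend on.
type Downloader interface {
	CreateTask(url string) (gid string, err error)
	Cancel(gid string) error
}

// dummyDownloader is installed as the package default until a real client is configured.
type dummyDownloader struct{}

func (dummyDownloader) CreateTask(string) (string, error) { return "", ErrNotEnabled }
func (dummyDownloader) Cancel(string) error               { return ErrNotEnabled }

var Instance Downloader = dummyDownloader{}

func main() {
	// Callers can always invoke Instance; the error tells them the feature is off.
	if _, err := Instance.CreateTask("https://example.com/file.iso"); err != nil {
		fmt.Println("create task:", err) // create task: feature not enabled
	}
}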
-func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) { - return "", ErrNotEnabled -} - -// Status 返回未开启错误 -func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) { - return rpc.StatusInfo{}, ErrNotEnabled -} - -// Cancel 返回未开启错误 -func (instance *DummyAria2) Cancel(task *model.Download) error { - return ErrNotEnabled -} - -// Select 返回未开启错误 -func (instance *DummyAria2) Select(task *model.Download, files []int) error { - return ErrNotEnabled -} - -// GetConfig 返回空的 -func (instance *DummyAria2) GetConfig() model.Aria2Option { - return model.Aria2Option{} -} - -// GetConfig 返回空的 -func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { - return ErrNotEnabled -} - -// GetStatus 将给定的状态字符串转换为状态标识数字 -func GetStatus(status rpc.StatusInfo) int { - switch status.Status { - case "complete": - return Complete - case "active": - if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength { - return Seeding - } - return Downloading - case "waiting": - return Ready - case "paused": - return Paused - case "error": - return Error - case "removed": - return Canceled - default: - return Unknown - } -} diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go deleted file mode 100644 index 7b0f2378..00000000 --- a/pkg/aria2/common/common_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package common - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/stretchr/testify/assert" -) - -func TestDummyAria2(t *testing.T) { - a := assert.New(t) - d := &DummyAria2{} - - a.NoError(d.Init()) - - res, err := d.CreateTask(&model.Download{}, map[string]interface{}{}) - a.Empty(res) - a.Error(err) - - _, err = d.Status(&model.Download{}) - a.Error(err) - - err = d.Cancel(&model.Download{}) - a.Error(err) - - err = d.Select(&model.Download{}, []int{}) - a.Error(err) - - configRes := d.GetConfig() - a.NotNil(configRes) - - err = d.DeleteTempFile(&model.Download{}) - a.Error(err) -} - -func TestGetStatus(t *testing.T) { - a := assert.New(t) - - a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: "single"}, - TotalLength: "100", CompletedLength: "50"}), Downloading) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: "multi"}, - TotalLength: "100", CompletedLength: "100"}), Seeding) - a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready) - a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused) - a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error) - a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled) - a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown) -} diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go deleted file mode 100644 index 69d14ffe..00000000 --- a/pkg/aria2/monitor/monitor.go +++ /dev/null @@ -1,314 +0,0 @@ -package monitor - -import ( - "context" - "encoding/json" - "errors" - "path/filepath" - "strconv" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - 
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// Monitor 离线下载状态监控 -type Monitor struct { - Task *model.Download - Interval time.Duration - - notifier <-chan mq.Message - node cluster.Node - retried int -} - -var MAX_RETRY = 10 - -// NewMonitor 新建离线下载状态监控 -func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) { - monitor := &Monitor{ - Task: task, - notifier: make(chan mq.Message), - node: pool.GetNodeByID(task.GetNodeID()), - } - - if monitor.node != nil { - monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second - go monitor.Loop(mqClient) - - monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0) - } else { - monitor.setErrorStatus(errors.New("node not avaliable")) - } -} - -// Loop 开启监控循环 -func (monitor *Monitor) Loop(mqClient mq.MQ) { - defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier) - - // 首次循环立即更新 - interval := 50 * time.Millisecond - - for { - select { - case <-monitor.notifier: - if monitor.Update() { - return - } - case <-time.After(interval): - interval = monitor.Interval - if monitor.Update() { - return - } - } - } -} - -// Update 更新状态,返回值表示是否退出监控 -func (monitor *Monitor) Update() bool { - status, err := monitor.node.GetAria2Instance().Status(monitor.Task) - - if err != nil { - monitor.retried++ - util.Log().Warning("Cannot get status of download task %q: %s", monitor.Task.GID, err) - - // 十次重试后认定为任务失败 - if monitor.retried > MAX_RETRY { - util.Log().Warning("Cannot get status of download task %q,exceed maximum retry threshold: %s", - monitor.Task.GID, err) - monitor.setErrorStatus(err) - monitor.RemoveTempFolder() - return true - } - - return false - } - monitor.retried = 0 - - // 磁力链下载需要跟随 - if len(status.FollowedBy) > 0 { - util.Log().Debug("Redirected download task from %q to %q.", monitor.Task.GID, status.FollowedBy[0]) - monitor.Task.GID = status.FollowedBy[0] - monitor.Task.Save() - return false - } - - // 更新任务信息 - if err := monitor.UpdateTaskInfo(status); err != nil { - util.Log().Warning("Failed to update status of download task %q: %s", monitor.Task.GID, err) - monitor.setErrorStatus(err) - monitor.RemoveTempFolder() - return true - } - - util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status) - - switch common.GetStatus(status) { - case common.Complete, common.Seeding: - return monitor.Complete(task.TaskPoll) - case common.Error: - return monitor.Error(status) - case common.Downloading, common.Ready, common.Paused: - return false - case common.Canceled: - monitor.Task.Status = common.Canceled - monitor.Task.Save() - monitor.RemoveTempFolder() - return true - default: - util.Log().Warning("Download task %q returns unknown status %q.", monitor.Task.GID, status.Status) - return true - } -} - -// UpdateTaskInfo 更新数据库中的任务信息 -func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { - originSize := monitor.Task.TotalSize - - monitor.Task.GID = status.Gid - monitor.Task.Status = common.GetStatus(status) - - // 文件大小、已下载大小 - total, err := strconv.ParseUint(status.TotalLength, 10, 64) - if err != nil { - total = 0 - } - downloaded, err := strconv.ParseUint(status.CompletedLength, 10, 64) - if err != nil { - downloaded = 0 - } - monitor.Task.TotalSize = total - monitor.Task.DownloadedSize = downloaded - monitor.Task.GID = status.Gid - monitor.Task.Parent = status.Dir - - // 下载速度 - speed, err := 
strconv.Atoi(status.DownloadSpeed) - if err != nil { - speed = 0 - } - - monitor.Task.Speed = speed - attrs, _ := json.Marshal(status) - monitor.Task.Attrs = string(attrs) - - if err := monitor.Task.Save(); err != nil { - return err - } - - if originSize != monitor.Task.TotalSize { - // 文件大小更新后,对文件限制等进行校验 - if err := monitor.ValidateFile(); err != nil { - // 验证失败时取消任务 - monitor.node.GetAria2Instance().Cancel(monitor.Task) - return err - } - } - - return nil -} - -// ValidateFile 上传过程中校验文件大小、文件名 -func (monitor *Monitor) ValidateFile() error { - // 找到任务创建者 - user := monitor.Task.GetOwner() - if user == nil { - return common.ErrUserNotFound - } - - // 创建文件系统 - fs, err := filesystem.NewFileSystem(user) - if err != nil { - return err - } - defer fs.Recycle() - - // 创建上下文环境 - file := &fsctx.FileStream{ - Size: monitor.Task.TotalSize, - } - - // 验证用户容量 - if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil { - return err - } - - // 验证每个文件 - for _, fileInfo := range monitor.Task.StatusInfo.Files { - if fileInfo.Selected == "true" { - // 创建上下文环境 - fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64) - file := &fsctx.FileStream{ - Size: fileSize, - Name: filepath.Base(fileInfo.Path), - } - if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil { - return err - } - } - - } - - return nil -} - -// Error 任务下载出错处理,返回是否中断监控 -func (monitor *Monitor) Error(status rpc.StatusInfo) bool { - monitor.setErrorStatus(errors.New(status.ErrorMessage)) - - // 清理临时文件 - monitor.RemoveTempFolder() - - return true -} - -// RemoveTempFolder 清理下载临时目录 -func (monitor *Monitor) RemoveTempFolder() { - monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task) -} - -// Complete 完成下载,返回是否中断监控 -func (monitor *Monitor) Complete(pool task.Pool) bool { - // 未开始转存,提交转存任务 - if monitor.Task.TaskID == 0 { - return monitor.transfer(pool) - } - - // 做种完成 - if common.GetStatus(monitor.Task.StatusInfo) == common.Complete { - transferTask, err := model.GetTasksByID(monitor.Task.TaskID) - if err != nil { - monitor.setErrorStatus(err) - monitor.RemoveTempFolder() - return true - } - - // 转存完成,回收下载目录 - if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error { - job, err := task.NewRecycleTask(monitor.Task) - if err != nil { - monitor.setErrorStatus(err) - monitor.RemoveTempFolder() - return true - } - - // 提交回收任务 - pool.Submit(job) - - return true - } - } - - return false -} - -func (monitor *Monitor) transfer(pool task.Pool) bool { - // 创建中转任务 - file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) - sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) - for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ { - fileInfo := monitor.Task.StatusInfo.Files[i] - if fileInfo.Selected == "true" { - file = append(file, fileInfo.Path) - size, _ := strconv.ParseUint(fileInfo.Length, 10, 64) - sizes[fileInfo.Path] = size - } - } - - job, err := task.NewTransferTask( - monitor.Task.UserID, - file, - monitor.Task.Dst, - monitor.Task.Parent, - true, - monitor.node.ID(), - sizes, - ) - if err != nil { - monitor.setErrorStatus(err) - monitor.RemoveTempFolder() - return true - } - - // 提交中转任务 - pool.Submit(job) - - // 更新任务ID - monitor.Task.TaskID = job.Model().ID - monitor.Task.Save() - - return false -} - -func (monitor *Monitor) setErrorStatus(err error) { - monitor.Task.Status = common.Error - monitor.Task.Error = err.Error() - monitor.Task.Save() -} diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go 
deleted file mode 100644 index a6be586a..00000000 --- a/pkg/aria2/monitor/monitor_test.go +++ /dev/null @@ -1,447 +0,0 @@ -package monitor - -import ( - "database/sql" - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestNewMonitor(t *testing.T) { - a := assert.New(t) - mockMQ := mq.NewMQ() - - // node not available - { - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", uint(1)).Return(nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task := &model.Download{ - Model: gorm.Model{ID: 1}, - } - NewMonitor(task, mockPool, mockMQ) - mockPool.AssertExpectations(t) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(task.Error) - } - - // success - { - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", uint(1)).Return(mockNode) - - task := &model.Download{ - Model: gorm.Model{ID: 1}, - } - NewMonitor(task, mockPool, mockMQ) - mockNode.AssertExpectations(t) - mockPool.AssertExpectations(t) - } - -} - -func TestMonitor_Loop(t *testing.T) { - a := assert.New(t) - mockMQ := mq.NewMQ() - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - retried: MAX_RETRY, - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - notifier: mockMQ.Subscribe("test", 1), - } - - // into interval loop - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - m.Loop(mockMQ) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) - } - - // into notifier loop - { - m.Task.Error = "" - mockMQ.Publish("test", mq.Message{}) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - m.Loop(mockMQ) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) - } -} - -func TestMonitor_UpdateFailedAfterRetry(t *testing.T) { - a := assert.New(t) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - for i := 0; i < MAX_RETRY; i++ { - a.False(m.Update()) - } - - mockNode.AssertExpectations(t) - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateMagentoFollow(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - FollowedBy: []string{"next"}, - }, nil) - mockNode := &mocks.NodeMock{} - 
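// Aside: the monitor tests below drive Update() through each aria2 status string. The mapping
// they rely on (from the deleted common.GetStatus) boils down to the sketch here; note that
// "active" only counts as seeding when the torrent has fully completed:
package main

import "fmt"

const (
	Ready = iota
	Downloading
	Paused
	Error
	Complete
	Canceled
	Unknown
	Seeding
)

type statusInfo struct {
	Status          string
	BitTorrentMode  string
	TotalLength     string
	CompletedLength string
}

func getStatus(s statusInfo) int {
	switch s.Status {
	case "complete":
		return Complete
	case "active":
		if s.BitTorrentMode != "" && s.CompletedLength == s.TotalLength {
			return Seeding
		}
		return Downloading
	case "waiting":
		return Ready
	case "paused":
		return Paused
	case "error":
		return Error
	case "removed":
		return Canceled
	default:
		return Unknown
	}
}

func main() {
	done := statusInfo{Status: "active", BitTorrentMode: "multi", TotalLength: "100", CompletedLength: "100"}
	fmt.Println(getStatus(done) == Seeding) // true
}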
mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.False(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("next", m.Task.GID) - mockAria2.AssertExpectations(t) -} - -func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateCompleted(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "complete", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - mockNode.On("ID").Return(uint(1)) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateError(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "error", - ErrorMessage: "error", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateActive(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "active", - }, nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.False(m.Update()) - 
a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateRemoved(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "removed", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.Equal(common.Canceled, m.Task.Status) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateUnknown(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "unknown", - }, nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) { - a := assert.New(t) - status := rpc.StatusInfo{ - Status: "completed", - TotalLength: "100", - CompletedLength: "50", - DownloadSpeed: "20", - } - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := m.UpdateTaskInfo(status) - a.Error(err) - a.NoError(mock.ExpectationsWereMet()) - mockNode.AssertExpectations(t) -} - -func TestMonitor_ValidateFile(t *testing.T) { - a := assert.New(t) - m := &Monitor{ - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - TotalSize: 100, - }, - } - - // failed to create filesystem - { - m.Task.User = &model.User{ - Policy: model.Policy{ - Type: "random", - }, - } - a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile()) - } - - // User capacity not enough - { - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 99, - }, - Policy: model.Policy{ - Type: "local", - }, - } - a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile()) - } - - // single file too big - { - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 100, - }, - Policy: model.Policy{ - Type: "local", - MaxSize: 99, - }, - } - a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile()) - } - - // all pass - { - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 100, - }, - Policy: model.Policy{ - Type: "local", - MaxSize: 100, - }, - } - a.NoError(m.ValidateFile()) - } -} - -func TestMonitor_Complete(t *testing.T) { - a := assert.New(t) - mockNode := &mocks.NodeMock{} - 
mockNode.On("ID").Return(uint(1)) - mockPool := &mocks.TaskPoolMock{} - mockPool.On("Submit", testMock.Anything) - m := &Monitor{ - node: mockNode, - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - TotalSize: 100, - UserID: 9414, - }, - } - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) - - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4)) - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - - a.False(m.Complete(mockPool)) - m.Task.StatusInfo.Status = "complete" - a.True(m.Complete(mockPool)) - a.NoError(mock.ExpectationsWereMet()) - mockNode.AssertExpectations(t) - mockPool.AssertExpectations(t) -} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 32a7e917..cbc66418 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -2,18 +2,18 @@ package auth import ( "bytes" + "context" "fmt" - "io/ioutil" + "io" "net/http" "net/url" + "regexp" "sort" "strings" "time" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" ) var ( @@ -23,37 +23,59 @@ var ( ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil) ) -const CrHeaderPrefix = "X-Cr-" +const ( + TokenHeaderPrefixCr = "Bearer Cr " +) // General 通用的认证接口 +// Deprecated var General Auth -// Auth 鉴权认证 -type Auth interface { - // 对给定Body进行签名,expires为0表示永不过期 - Sign(body string, expires int64) string - // 对给定Body和Sign进行检查 - Check(body string, sign string) error -} +type ( + // Auth 鉴权认证 + Auth interface { + // 对给定Body进行签名,expires为0表示永不过期 + Sign(body string, expires int64) string + // 对给定Body和Sign进行检查 + Check(body string, sign string) error + } +) // SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、 // 请求正文、`X-Cr-`开头的header进行签名 -func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request { +func SignRequest(ctx context.Context, instance Auth, r *http.Request, expires *time.Time) *http.Request { + // 处理有效期 + expireTime := int64(0) + if expires != nil { + expireTime = expires.Unix() + } + + // 生成签名 + sign := instance.Sign(getSignContent(ctx, r), expireTime) + + // 将签名加到请求Header中 + r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign} + return r +} + +// SignRequestDeprecated 对PUT\POST等复杂HTTP请求签名,只会对URI部分、 +// 请求正文、`X-Cr-`开头的header进行签名 +func SignRequestDeprecated(instance Auth, r *http.Request, expires int64) *http.Request { // 处理有效期 if expires > 0 { expires += time.Now().Unix() } // 生成签名 - sign := instance.Sign(getSignContent(r), expires) + sign := instance.Sign(getSignContent(context.Background(), r), expires) // 将签名加到请求Header中 - r.Header["Authorization"] = []string{"Bearer " + sign} + r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign} return r } // CheckRequest 对复杂请求进行签名验证 -func 
CheckRequest(instance Auth, r *http.Request) error { +func CheckRequest(ctx context.Context, instance Auth, r *http.Request) error { var ( sign []string ok bool @@ -61,41 +83,71 @@ func CheckRequest(instance Auth, r *http.Request) error { if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 { return ErrAuthHeaderMissing } - sign[0] = strings.TrimPrefix(sign[0], "Bearer ") + sign[0] = strings.TrimPrefix(sign[0], TokenHeaderPrefixCr) + + return instance.Check(getSignContent(ctx, r), sign[0]) +} + +func isUploadDataRequest(r *http.Request) bool { + return strings.Contains(r.URL.Path, constants.APIPrefix+"/slave/upload/") && r.Method != http.MethodPut - return instance.Check(getSignContent(r), sign[0]) } // getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API, // 则不对正文签名。返回待签名/验证的字符串 -func getSignContent(r *http.Request) (rawSignString string) { +func getSignContent(ctx context.Context, r *http.Request) (rawSignString string) { // 读取所有body正文 var body = []byte{} - if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") { + if !isUploadDataRequest(r) { if r.Body != nil { - body, _ = ioutil.ReadAll(r.Body) + body, _ = io.ReadAll(r.Body) _ = r.Body.Close() - r.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.Body = io.NopCloser(bytes.NewReader(body)) } } // 决定要签名的header var signedHeader []string for k, _ := range r.Header { - if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" { + if strings.HasPrefix(k, constants.CrHeaderPrefix) && k != constants.CrHeaderPrefix+"Filename" { signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k))) } } sort.Strings(signedHeader) // 读取所有待签名Header - rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body)) + rawSignString = serializer.NewRequestSignString(getUrlSignContent(ctx, r.URL), strings.Join(signedHeader, "&"), string(body)) return rawSignString } -// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证 -func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) { +// SignURI 对URI进行签名 +func SignURI(ctx context.Context, instance Auth, uri string, expires *time.Time) (*url.URL, error) { + // 处理有效期 + expireTime := int64(0) + if expires != nil { + expireTime = expires.Unix() + } + + base, err := url.Parse(uri) + if err != nil { + return nil, err + } + + // 生成签名 + sign := instance.Sign(getUrlSignContent(ctx, base), expireTime) + + // 将签名加到URI中 + queries := base.Query() + queries.Set("sign", sign) + base.RawQuery = queries.Encode() + + return base, nil +} + +// SignURIDeprecated 对URI进行签名,签名只针对Path部分,query部分不做验证 +// Deprecated +func SignURIDeprecated(instance Auth, uri string, expires int64) (*url.URL, error) { // 处理有效期 if expires != 0 { expires += time.Now().Unix() @@ -118,28 +170,55 @@ func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) { } // CheckURI 对URI进行鉴权 -func CheckURI(instance Auth, url *url.URL) error { +func CheckURI(ctx context.Context, instance Auth, url *url.URL) error { //获取待验证的签名正文 queries := url.Query() sign := queries.Get("sign") queries.Del("sign") url.RawQuery = queries.Encode() - return instance.Check(url.Path, sign) + return instance.Check(getUrlSignContent(ctx, url), sign) } -// Init 初始化通用鉴权器 -func Init() { - var secretKey string - if conf.SystemConfig.Mode == "master" { - secretKey = model.GetSettingByName("secret_key") - } else { - secretKey = conf.SlaveConfig.Secret - if secretKey == "" { - util.Log().Panic("SlaveSecret is not set, please specify it in config file.") +func 
RedactSensitiveValues(errorMessage string) string { + // Regular expression to match URLs + urlRegex := regexp.MustCompile(`https?://[^\s]+`) + // Find all URLs in the error message + urls := urlRegex.FindAllString(errorMessage, -1) + + for _, urlStr := range urls { + // Parse the URL + parsedURL, err := url.Parse(urlStr) + if err != nil { + continue } + + // Get the query parameters + queryParams := parsedURL.Query() + + // Redact the 'sign' parameter if it exists + if _, exists := queryParams["sign"]; exists { + queryParams.Set("sign", "REDACTED") + parsedURL.RawQuery = queryParams.Encode() + } + + // Replace the original URL with the redacted one in the error message + errorMessage = strings.Replace(errorMessage, urlStr, parsedURL.String(), -1) } - General = HMACAuth{ - SecretKey: []byte(secretKey), - } + + return errorMessage +} + +func getUrlSignContent(ctx context.Context, url *url.URL) string { + // host := url.Host + // if host == "" { + // reqInfo := requestinfo.RequestInfoFromContext(ctx) + // if reqInfo != nil { + // host = reqInfo.Host + // } + // } + // host = strings.TrimSuffix(host, "/") + // // remove port if it exists + // host = strings.Split(host, ":")[0] + return url.Path } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go deleted file mode 100644 index 42c5603e..00000000 --- a/pkg/auth/auth_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package auth - -import ( - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" -) - -func TestSignURI(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 成功 - { - sign, err := SignURI(General, "/api/v3/something?id=1", 0) - asserts.NoError(err) - queries := sign.Query() - asserts.Equal("1", queries.Get("id")) - asserts.NotEmpty(queries.Get("sign")) - } - - // URI解码失败 - { - sign, err := SignURI(General, "://dg.;'f]gh./'", 0) - asserts.Error(err) - asserts.Nil(sign) - } -} - -func TestCheckURI(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 成功 - { - sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10) - asserts.NoError(err) - asserts.NoError(CheckURI(General, sign)) - } - - // 过期 - { - sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1) - asserts.NoError(err) - asserts.Error(CheckURI(General, sign)) - } -} - -func TestSignRequest(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 非上传请求 - { - req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body.")) - asserts.NoError(err) - req = SignRequest(General, req, 0) - asserts.NotEmpty(req.Header["Authorization"]) - } - - // 上传请求 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/slave/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req.Header["X-Cr-Policy"] = []string{"I am Policy"} - req = SignRequest(General, req, 10) - asserts.NotEmpty(req.Header["Authorization"]) - } -} - -func TestCheckRequest(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 缺少请求头 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - err = CheckRequest(General, req) - asserts.Error(err) - asserts.Equal(ErrAuthHeaderMissing, err) - } - - // 非上传请求 验证成功 - { - req, err := 
http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req = SignRequest(General, req, 0) - err = CheckRequest(General, req) - asserts.NoError(err) - } - - // 上传请求 验证成功 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req.Header["X-Cr-Policy"] = []string{"I am Policy"} - req = SignRequest(General, req, 0) - err = CheckRequest(General, req) - asserts.NoError(err) - } - - // 非上传请求 失败 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req = SignRequest(General, req, 0) - req.Body = ioutil.NopCloser(strings.NewReader("2333")) - err = CheckRequest(General, req) - asserts.Error(err) - } -} diff --git a/pkg/auth/hmac.go b/pkg/auth/hmac.go index 50849cc3..2654249b 100644 --- a/pkg/auth/hmac.go +++ b/pkg/auth/hmac.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" ) // HMACAuth HMAC算法鉴权 @@ -39,7 +41,7 @@ func (auth HMACAuth) Check(body string, sign string) error { // 验证是否过期 expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64) if err != nil { - return ErrAuthFailed.WithError(err) + return serializer.NewError(serializer.CodeInvalidSign, "sign expired", nil) } // 如果签名过期 if expires < time.Now().Unix() && expires != 0 { @@ -48,7 +50,7 @@ func (auth HMACAuth) Check(body string, sign string) error { // 验证签名 if auth.Sign(body, expires) != sign { - return ErrAuthFailed + return serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil) } return nil } diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go deleted file mode 100644 index 706f617e..00000000 --- a/pkg/auth/hmac_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package auth - -import ( - "database/sql" - "fmt" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -func TestMain(m *testing.M) { - // 设置gin为测试模式 - gin.SetMode(gin.TestMode) - - // 初始化sqlmock - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - - mockDB, _ := gorm.Open("mysql", db) - model.DB = mockDB - defer db.Close() - - m.Run() -} - -func TestHMACAuth_Sign(t *testing.T) { - asserts := assert.New(t) - auth := HMACAuth{ - SecretKey: []byte(util.RandStringRunes(256)), - } - - asserts.NotEmpty(auth.Sign("content", 0)) -} - -func TestHMACAuth_Check(t *testing.T) { - asserts := assert.New(t) - auth := HMACAuth{ - SecretKey: []byte(util.RandStringRunes(256)), - } - - // 正常,永不过期 - { - sign := auth.Sign("content", 0) - asserts.NoError(auth.Check("content", sign)) - } - - // 过期 - { - sign := auth.Sign("content", 1) - asserts.Error(auth.Check("content", sign)) - } - - // 签名格式错误 - { - sign := auth.Sign("content", 1) - asserts.Error(auth.Check("content", sign+":")) - } - - // 过期日期格式错误 - { - asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed")) - } - - // 签名有误 - { - asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10))) - } -} - -func TestInit(t *testing.T) { - asserts := assert.New(t) - 
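// Aside: HMACAuth (pkg/auth/hmac.go, partially shown in this patch) signs a payload together
// with an expiry timestamp, and Check() re-signs and compares, rejecting expired signatures.
// A minimal standalone sketch of that sign/verify shape; the exact payload layout and error
// types in Cloudreve differ, this only illustrates the scheme:
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"
)

type hmacAuth struct{ secret []byte }

// sign returns "<base64 mac>:<expires>"; expires == 0 means the signature never expires.
func (a hmacAuth) sign(body string, expires int64) string {
	h := hmac.New(sha256.New, a.secret)
	h.Write([]byte(body + ":" + strconv.FormatInt(expires, 10)))
	return base64.URLEncoding.EncodeToString(h.Sum(nil)) + ":" + strconv.FormatInt(expires, 10)
}

// check re-derives the signature for the claimed expiry and compares it to the presented one.
func (a hmacAuth) check(body, sign string) error {
	parts := strings.Split(sign, ":")
	expires, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
	if err != nil {
		return errors.New("malformed signature")
	}
	if expires != 0 && expires < time.Now().Unix() {
		return errors.New("signature expired")
	}
	if a.sign(body, expires) != sign {
		return errors.New("invalid signature")
	}
	return nil
}

func main() {
	auth := hmacAuth{secret: []byte("demo-secret")}
	s := auth.sign("/api/v4/slave/ping", time.Now().Add(time.Minute).Unix())
	fmt.Println(auth.check("/api/v4/slave/ping", s)) // <nil>
}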
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312")) - Init() - asserts.NoError(mock.ExpectationsWereMet()) - - // slave模式 - conf.SystemConfig.Mode = "slave" - asserts.Panics(func() { - Init() - }) -} diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go new file mode 100644 index 00000000..9fe0d0c2 --- /dev/null +++ b/pkg/auth/jwt.go @@ -0,0 +1,200 @@ +package auth + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + "fmt" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +type TokenAuth interface { + // Issue issues a new pair of credentials for the given user. + Issue(ctx context.Context, u *ent.User) (*Token, error) + // VerifyAndRetrieveUser verifies the given token and inject the user into current context. + // Returns if upper caller should continue process other session provider. + VerifyAndRetrieveUser(c *gin.Context) (bool, error) + // Refresh refreshes the given refresh token and returns a new pair of credentials. + Refresh(ctx context.Context, refreshToken string) (*Token, error) +} + +// Token stores token pair for authentication +type Token struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccessExpires time.Time `json:"access_expires"` + RefreshExpires time.Time `json:"refresh_expires"` + + UID int `json:"-"` +} + +type ( + TokenType string + TokenIDContextKey struct{} +) + +var ( + TokenTypeAccess = TokenType("access") + TokenTypeRefresh = TokenType("refresh") + + ErrInvalidRefreshToken = errors.New("invalid refresh token") + ErrUserNotFound = errors.New("user not found") +) + +const ( + AuthorizationHeader = "Authorization" + TokenHeaderPrefix = "Bearer " +) + +type Claims struct { + TokenType TokenType `json:"token_type"` + jwt.RegisteredClaims + StateHash []byte `json:"state_hash,omitempty"` +} + +// NewTokenAuth creates a new token based auth provider. 
+func NewTokenAuth(idEncoder hashid.Encoder, s setting.Provider, secret []byte, userClient inventory.UserClient, l logging.Logger) TokenAuth { + return &tokenAuth{ + idEncoder: idEncoder, + s: s, + secret: secret, + userClient: userClient, + l: l, + } +} + +type tokenAuth struct { + l logging.Logger + idEncoder hashid.Encoder + s setting.Provider + secret []byte + userClient inventory.UserClient +} + +func (t *tokenAuth) Refresh(ctx context.Context, refreshToken string) (*Token, error) { + token, err := jwt.ParseWithClaims(refreshToken, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return t.secret, nil + }) + + if err != nil { + return nil, fmt.Errorf("invalid refresh token: %w", err) + } + + claims, ok := token.Claims.(*Claims) + if !ok || claims.TokenType != TokenTypeRefresh { + return nil, ErrInvalidRefreshToken + } + + uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID) + if err != nil { + return nil, ErrUserNotFound + } + + expectedUser, err := t.userClient.GetActiveByID(ctx, uid) + if err != nil { + return nil, ErrUserNotFound + } + + // Check if user changed password or revoked session + expectedHash := t.hashUserState(ctx, expectedUser) + if !bytes.Equal(claims.StateHash, expectedHash[:]) { + return nil, ErrInvalidRefreshToken + } + + return t.Issue(ctx, expectedUser) +} + +func (t *tokenAuth) VerifyAndRetrieveUser(c *gin.Context) (bool, error) { + headerVal := c.GetHeader(AuthorizationHeader) + if strings.HasPrefix(headerVal, TokenHeaderPrefixCr) { + // This is an HMAC auth header, skip JWT verification + return false, nil + } + + tokenString := strings.TrimPrefix(headerVal, TokenHeaderPrefix) + if tokenString == "" { + return true, nil + } + + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return t.secret, nil + }) + + if err != nil { + t.l.Warning("Failed to parse jwt token: %s", err) + return false, nil + } + + claims, ok := token.Claims.(*Claims) + if !ok || claims.TokenType != TokenTypeAccess { + return false, serializer.NewError(serializer.CodeCredentialInvalid, "Invalid token type", nil) + } + + uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID) + if err != nil { + return false, serializer.NewError(serializer.CodeNotFound, "User not found", err) + } + + util.WithValue(c, inventory.UserIDCtx{}, uid) + return false, nil +} + +func (t *tokenAuth) Issue(ctx context.Context, u *ent.User) (*Token, error) { + uidEncoded := hashid.EncodeUserID(t.idEncoder, u.ID) + tokenSettings := t.s.TokenAuth(ctx) + issueDate := time.Now() + accessTokenExpired := time.Now().Add(tokenSettings.AccessTokenTTL) + refreshTokenExpired := time.Now().Add(tokenSettings.RefreshTokenTTL) + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ + TokenType: TokenTypeAccess, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: uidEncoded, + NotBefore: jwt.NewNumericDate(issueDate), + ExpiresAt: jwt.NewNumericDate(accessTokenExpired), + }, + }).SignedString(t.secret) + if err != nil { + return nil, fmt.Errorf("faield to sign access token: %w", err) + } + + userHash := t.hashUserState(ctx, u) + refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ + TokenType: TokenTypeRefresh, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: uidEncoded, + NotBefore: jwt.NewNumericDate(issueDate), + ExpiresAt: jwt.NewNumericDate(refreshTokenExpired), + }, + StateHash: userHash[:], + }).SignedString(t.secret) + if err != nil { + return nil, fmt.Errorf("faield to sign refresh token: %w", err) + } + + 
return &Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + AccessExpires: accessTokenExpired, + RefreshExpires: refreshTokenExpired, + UID: u.ID, + }, nil +} + +// hashUserState returns a hash string for user state for critical fields, it is used +// to detect refresh token revocation after user changed password. +func (t *tokenAuth) hashUserState(ctx context.Context, u *ent.User) [32]byte { + return sha256.Sum256([]byte(fmt.Sprintf("%s/%s/%s", u.Email, u.Password, t.s.SiteBasic(ctx).ID))) +} diff --git a/pkg/auth/requestinfo/requestinfo.go b/pkg/auth/requestinfo/requestinfo.go new file mode 100644 index 00000000..55c09929 --- /dev/null +++ b/pkg/auth/requestinfo/requestinfo.go @@ -0,0 +1,25 @@ +package requestinfo + +import ( + "context" +) + +// RequestInfoCtx context key for RequestInfo +type RequestInfoCtx struct{} + +// RequestInfoFromContext retrieves RequestInfo from context +func RequestInfoFromContext(ctx context.Context) *RequestInfo { + v, ok := ctx.Value(RequestInfoCtx{}).(*RequestInfo) + if !ok { + return nil + } + + return v +} + +// RequestInfo store request info for audit +type RequestInfo struct { + Host string + IP string + UserAgent string +} diff --git a/pkg/authn/auth.go b/pkg/authn/auth.go deleted file mode 100644 index 5c5b4b71..00000000 --- a/pkg/authn/auth.go +++ /dev/null @@ -1,16 +0,0 @@ -package authn - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/duo-labs/webauthn/webauthn" -) - -// NewAuthnInstance 新建Authn实例 -func NewAuthnInstance() (*webauthn.WebAuthn, error) { - base := model.GetSiteURL() - return webauthn.New(&webauthn.Config{ - RPDisplayName: model.GetSettingByName("siteName"), // Display Name for your site - RPID: base.Hostname(), // Generally the FQDN for your site - RPOrigin: base.String(), // The origin URL for WebAuthn requests - }) -} diff --git a/pkg/authn/auth_test.go b/pkg/authn/auth_test.go deleted file mode 100644 index 3df60cfc..00000000 --- a/pkg/authn/auth_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package authn - -import ( - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/stretchr/testify/assert" -) - -func TestInit(t *testing.T) { - asserts := assert.New(t) - cache.Set("setting_siteURL", "http://cloudreve.org", 0) - cache.Set("setting_siteName", "Cloudreve", 0) - res, err := NewAuthnInstance() - asserts.NotNil(res) - asserts.NoError(err) -} diff --git a/pkg/boolset/boolset.go b/pkg/boolset/boolset.go new file mode 100644 index 00000000..65ecd9d1 --- /dev/null +++ b/pkg/boolset/boolset.go @@ -0,0 +1,86 @@ +package boolset + +import ( + "database/sql/driver" + "encoding/base64" + "errors" + "golang.org/x/exp/constraints" +) + +var ( + ErrValueNotSupported = errors.New("value not supported") +) + +type BooleanSet []byte + +// FromString convert from base64 encoded boolset. 
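// Aside: BooleanSet below packs feature flags into a byte slice (flag/8 selects the byte, the
// low bits select the position) and ships the set as a base64 string. A standalone sketch of
// that encoding; the setter is added here purely for illustration and is not part of this diff:
package main

import (
	"encoding/base64"
	"fmt"
)

type boolSet []byte

// set grows the slice as needed and turns on the bit for flag.
func (b *boolSet) set(flag int) {
	for flag/8 >= len(*b) {
		*b = append(*b, 0)
	}
	(*b)[flag/8] |= 1 << (flag % 8)
}

// enabled reports whether the bit for flag is on; out-of-range flags read as false.
func (b boolSet) enabled(flag int) bool {
	if flag/8 >= len(b) {
		return false
	}
	return b[flag/8]&(1<<(flag%8)) != 0
}

func main() {
	var b boolSet
	b.set(3)
	b.set(9)
	fmt.Println(b.enabled(3), b.enabled(4), b.enabled(9)) // true false true
	fmt.Println(base64.StdEncoding.EncodeToString(b))     // "CAI="
}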
+func FromString(data string) (*BooleanSet, error) { + raw, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + + b := BooleanSet(raw) + return &b, nil +} + +func (b *BooleanSet) UnmarshalBinary(data []byte) error { + *b = make(BooleanSet, len(data)) + copy(*b, data) + return nil +} + +func (b *BooleanSet) MarshalBinary() (data []byte, err error) { + return *b, nil +} + +func (b *BooleanSet) String() (data string, err error) { + raw, err := b.MarshalBinary() + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(raw), nil +} + +func (b *BooleanSet) Enabled(flag int) bool { + if flag >= len(*b)*8 { + return false + } + + return (*b)[flag/8]&(1< 0 && item.Expires < time.Now().Unix() { util.Log().Debug("Cache %q is garbage collected.", key.(string)) @@ -67,25 +68,33 @@ func (store *MemoStore) GarbageCollect() { } // NewMemoStore 新建内存存储 -func NewMemoStore() *MemoStore { - return &MemoStore{ +func NewMemoStore(persistFile string, l logging.Logger) *MemoStore { + store := &MemoStore{ Store: &sync.Map{}, } + + if persistFile != "" { + if err := store.Restore(persistFile); err != nil { + l.Warning("Failed to restore cache from disk: %s", err) + } + } + + return store } // Set 存储值 -func (store *MemoStore) Set(key string, value interface{}, ttl int) error { +func (store *MemoStore) Set(key string, value any, ttl int) error { store.Store.Store(key, newItem(value, ttl)) return nil } // Get 取值 -func (store *MemoStore) Get(key string) (interface{}, bool) { +func (store *MemoStore) Get(key string) (any, bool) { return getValue(store.Store.Load(key)) } // Gets 批量取值 -func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { +func (store *MemoStore) Gets(keys []string, prefix string) (map[string]any, []string) { var res = make(map[string]interface{}) var notFound = make([]string, 0, len(keys)) @@ -101,7 +110,7 @@ func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface } // Sets 批量设置值 -func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error { +func (store *MemoStore) Sets(values map[string]any, prefix string) error { for key, value := range values { store.Store.Store(prefix+key, newItem(value, 0)) } @@ -109,17 +118,27 @@ func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error } // Delete 批量删除值 -func (store *MemoStore) Delete(keys []string, prefix string) error { +func (store *MemoStore) Delete(prefix string, keys ...string) error { for _, key := range keys { store.Store.Delete(prefix + key) } + + // No key is presented, delete all entries with given prefix. 
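// Aside: the reworked cache Delete(prefix, keys...) treats an empty key list as "drop every
// entry under this prefix", as the memory-store code just below implements. A standalone
// sketch of that behaviour over a sync.Map (illustrative only, not the cache.Driver interface):
package main

import (
	"fmt"
	"strings"
	"sync"
)

type store struct{ m sync.Map }

func (s *store) set(key string, v any) { s.m.Store(key, v) }

func (s *store) delete(prefix string, keys ...string) {
	for _, k := range keys {
		s.m.Delete(prefix + k)
	}
	// No explicit keys: walk the map and remove everything under the prefix.
	if len(keys) == 0 {
		s.m.Range(func(key, _ any) bool {
			if k, ok := key.(string); ok && strings.HasPrefix(k, prefix) {
				s.m.Delete(key)
			}
			return true
		})
	}
}

func main() {
	s := &store{}
	s.set("policy_1", "a")
	s.set("policy_2", "b")
	s.set("user_1", "c")
	s.delete("policy_") // wildcard form: removes policy_1 and policy_2
	s.m.Range(func(k, _ any) bool { fmt.Println(k); return true }) // prints: user_1
}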
+ if len(keys) == 0 { + store.Store.Range(func(key, value any) bool { + if k, ok := key.(string); ok && strings.HasPrefix(k, prefix) { + store.Store.Delete(key) + } + return true + }) + } return nil } // Persist write memory store into cache func (store *MemoStore) Persist(path string) error { persisted := make(map[string]itemWithTTL) - store.Store.Range(func(key, value interface{}) bool { + store.Store.Range(func(key, value any) bool { v, ok := store.Store.Load(key) if _, ok := getValue(v, ok); ok { persisted[key.(string)] = v.(itemWithTTL) @@ -173,3 +192,12 @@ func (store *MemoStore) Restore(path string) error { util.Log().Info("Restored %d items from %q into memory cache.", loaded, path) return nil } + +func (store *MemoStore) DeleteAll() error { + store.Store.Range(func(key any, value any) bool { + store.Store.Delete(key) + return true + }) + + return nil +} diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go index be905770..2efbb703 100644 --- a/pkg/cache/memo_test.go +++ b/pkg/cache/memo_test.go @@ -2,7 +2,6 @@ package cache import ( "github.com/stretchr/testify/assert" - "path/filepath" "testing" "time" ) @@ -146,46 +145,3 @@ func TestMemoStore_GarbageCollect(t *testing.T) { _, ok := store.Get("test") asserts.False(ok) } - -func TestMemoStore_PersistFailed(t *testing.T) { - a := assert.New(t) - store := NewMemoStore() - type testStruct struct{ v string } - store.Set("test", 1, 0) - store.Set("test2", testStruct{v: "test"}, 0) - err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")) - a.Error(err) -} - -func TestMemoStore_PersistAndRestore(t *testing.T) { - a := assert.New(t) - store := NewMemoStore() - store.Set("test", 1, 0) - // already expired - store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1}) - // expired after persist - store.Set("test3", 1, 1) - temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed") - - // Persist - err := store.Persist(temp) - a.NoError(err) - a.FileExists(temp) - - time.Sleep(2 * time.Second) - // Restore - store2 := NewMemoStore() - err = store2.Restore(temp) - a.NoError(err) - test, testOk := store2.Get("test") - a.EqualValues(1, test) - a.True(testOk) - test2, test2Ok := store2.Get("test2") - a.Nil(test2) - a.False(test2Ok) - test3, test3Ok := store2.Get("test3") - a.Nil(test3) - a.False(test3Ok) - - a.NoFileExists(temp) -} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 5c776a0a..0cdee1a3 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -3,10 +3,10 @@ package cache import ( "bytes" "encoding/gob" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" "strconv" "time" - "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gomodule/redigo/redis" ) @@ -19,7 +19,7 @@ type item struct { Value interface{} } -func serializer(value interface{}) ([]byte, error) { +func serializer(value any) ([]byte, error) { var buffer bytes.Buffer enc := gob.NewEncoder(&buffer) storeValue := item{ @@ -32,7 +32,7 @@ func serializer(value interface{}) ([]byte, error) { return buffer.Bytes(), nil } -func deserializer(value []byte) (interface{}, error) { +func deserializer(value []byte) (any, error) { var res item buffer := bytes.NewReader(value) dec := gob.NewDecoder(buffer) @@ -44,7 +44,7 @@ func deserializer(value []byte) (interface{}, error) { } // NewRedisStore 创建新的redis存储 -func NewRedisStore(size int, network, address, user, password, database string) *RedisStore { +func NewRedisStore(l logging.Logger, size int, network, address, user, password, database string) *RedisStore { return 
&RedisStore{ pool: &redis.Pool{ MaxIdle: size, @@ -63,11 +63,11 @@ func NewRedisStore(size int, network, address, user, password, database string) network, address, redis.DialDatabase(db), - redis.DialUsername(user), redis.DialPassword(password), + redis.DialUsername(user), ) if err != nil { - util.Log().Panic("Failed to create Redis connection: %s", err) + l.Panic("Failed to create Redis connection: %s", err) } return c, nil }, @@ -76,7 +76,7 @@ func NewRedisStore(size int, network, address, user, password, database string) } // Set 存储值 -func (store *RedisStore) Set(key string, value interface{}, ttl int) error { +func (store *RedisStore) Set(key string, value any, ttl int) error { rc := store.pool.Get() defer rc.Close() @@ -103,7 +103,7 @@ func (store *RedisStore) Set(key string, value interface{}, ttl int) error { } // Get 取值 -func (store *RedisStore) Get(key string) (interface{}, bool) { +func (store *RedisStore) Get(key string) (any, bool) { rc := store.pool.Get() defer rc.Close() if rc.Err() != nil { @@ -125,7 +125,7 @@ func (store *RedisStore) Get(key string) (interface{}, bool) { } // Gets 批量取值 -func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { +func (store *RedisStore) Gets(keys []string, prefix string) (map[string]any, []string) { rc := store.pool.Get() defer rc.Close() if rc.Err() != nil { @@ -142,7 +142,7 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac return nil, keys } - var res = make(map[string]interface{}) + var res = make(map[string]any) var missed = make([]string, 0, len(keys)) for key, value := range v { @@ -158,13 +158,13 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac } // Sets 批量设置值 -func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error { +func (store *RedisStore) Sets(values map[string]any, prefix string) error { rc := store.pool.Get() defer rc.Close() if rc.Err() != nil { return rc.Err() } - var setValues = make(map[string]interface{}) + var setValues = make(map[string]any) // 编码待设置值 for key, value := range values { @@ -184,7 +184,7 @@ func (store *RedisStore) Sets(values map[string]interface{}, prefix string) erro } // Delete 批量删除给定的键 -func (store *RedisStore) Delete(keys []string, prefix string) error { +func (store *RedisStore) Delete(prefix string, keys ...string) error { rc := store.pool.Get() defer rc.Close() if rc.Err() != nil { @@ -196,10 +196,24 @@ func (store *RedisStore) Delete(keys []string, prefix string) error { keys[i] = prefix + keys[i] } - _, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...) - if err != nil { - return err + // No key is presented, delete all keys with given prefix + if len(keys) == 0 { + // Fetch all key with given prefix + allPrefixKeys, err := redis.Strings(rc.Do("KEYS", prefix+"*")) + if err != nil { + return err + } + + keys = allPrefixKeys } + + if len(keys) > 0 { + _, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...) 
+ if err != nil { + return err + } + } + return nil } diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go index c9f16923..609eba16 100644 --- a/pkg/cache/redis_test.go +++ b/pkg/cache/redis_test.go @@ -13,16 +13,16 @@ import ( func TestNewRedisStore(t *testing.T) { asserts := assert.New(t) - store := NewRedisStore(10, "tcp", "", "", "", "0") + store := NewRedisStore(10, "tcp", "", "", "0") asserts.NotNil(store) - asserts.Panics(func() { - store.pool.Dial() - }) + conn, err := store.pool.Dial() + asserts.Nil(conn) + asserts.Error(err) testConn := redigomock.NewConn() cmd := testConn.Command("PING").Expect("PONG") - err := store.pool.TestOnBorrow(testConn, time.Now()) + err = store.pool.TestOnBorrow(testConn, time.Now()) if testConn.Stats(cmd) != 1 { fmt.Println("Command was not used") return diff --git a/pkg/cluster/controller.go b/pkg/cluster/controller.go deleted file mode 100644 index 85fb178c..00000000 --- a/pkg/cluster/controller.go +++ /dev/null @@ -1,209 +0,0 @@ -package cluster - -import ( - "bytes" - "encoding/gob" - "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/jinzhu/gorm" - "net/url" - "sync" -) - -var DefaultController Controller - -// Controller controls communications between master and slave -type Controller interface { - // Handle heartbeat sent from master - HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error) - - // Get Aria2 Instance by master node ID - GetAria2Instance(string) (common.Aria2, error) - - // Send event change message to master node - SendNotification(string, string, mq.Message) error - - // Submit async task into task pool - SubmitTask(string, interface{}, string, func(interface{})) error - - // Get master node info - GetMasterInfo(string) (*MasterInfo, error) - - // Get master Oauth based policy credential - GetPolicyOauthToken(string, uint) (string, error) -} - -type slaveController struct { - masters map[string]MasterInfo - lock sync.RWMutex -} - -// info of master node -type MasterInfo struct { - ID string - TTL int - URL *url.URL - // used to invoke aria2 rpc calls - Instance Node - Client request.Client - - jobTracker map[string]bool -} - -func InitController() { - DefaultController = &slaveController{ - masters: make(map[string]MasterInfo), - } - gob.Register(rpc.StatusInfo{}) -} - -func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) { - c.lock.Lock() - defer c.lock.Unlock() - - req.Node.AfterFind() - - // close old node if exist - origin, ok := c.masters[req.SiteID] - - if (ok && req.IsUpdate) || !ok { - if ok { - origin.Instance.Kill() - } - - masterUrl, err := url.Parse(req.SiteURL) - if err != nil { - return serializer.NodePingResp{}, err - } - - c.masters[req.SiteID] = MasterInfo{ - ID: req.SiteID, - URL: masterUrl, - TTL: req.CredentialTTL, - Client: request.NewClient( - request.WithEndpoint(masterUrl.String()), - request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)), - request.WithCredential(auth.HMACAuth{ - SecretKey: []byte(req.Node.MasterKey), - }, int64(req.CredentialTTL)), - ), - jobTracker: make(map[string]bool), - Instance: NewNodeFromDBModel(&model.Node{ - Model: gorm.Model{ID: req.Node.ID}, - MasterKey: req.Node.MasterKey, - Type: 
model.MasterNodeType, - Aria2Enabled: req.Node.Aria2Enabled, - Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized, - }), - } - } - - return serializer.NodePingResp{}, nil -} - -func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { - c.lock.RLock() - defer c.lock.RUnlock() - - if node, ok := c.masters[id]; ok { - return node.Instance.GetAria2Instance(), nil - } - - return nil, ErrMasterNotFound -} - -func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error { - c.lock.RLock() - - if node, ok := c.masters[id]; ok { - c.lock.RUnlock() - - body := bytes.Buffer{} - enc := gob.NewEncoder(&body) - if err := enc.Encode(&msg); err != nil { - return err - } - - res, err := node.Client.Request( - "PUT", - fmt.Sprintf("/api/v3/slave/notification/%s", subject), - &body, - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - return nil - } - - c.lock.RUnlock() - return ErrMasterNotFound -} - -// SubmitTask 提交异步任务 -func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error { - c.lock.RLock() - defer c.lock.RUnlock() - - if node, ok := c.masters[id]; ok { - if _, ok := node.jobTracker[hash]; ok { - // 任务已存在,直接返回 - return nil - } - - node.jobTracker[hash] = true - submitter(job) - return nil - } - - return ErrMasterNotFound -} - -// GetMasterInfo 获取主机节点信息 -func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { - c.lock.RLock() - defer c.lock.RUnlock() - - if node, ok := c.masters[id]; ok { - return &node, nil - } - - return nil, ErrMasterNotFound -} - -// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证 -func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) { - c.lock.RLock() - - if node, ok := c.masters[id]; ok { - c.lock.RUnlock() - - res, err := node.Client.Request( - "GET", - fmt.Sprintf("/api/v3/slave/credential/%d", policyID), - nil, - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return "", err - } - - if res.Code != 0 { - return "", serializer.NewErrorFromResponse(res) - } - - return res.Data.(string), nil - } - - c.lock.RUnlock() - return "", ErrMasterNotFound -} diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go deleted file mode 100644 index 42d83627..00000000 --- a/pkg/cluster/controller_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package cluster - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" -) - -func TestInitController(t *testing.T) { - assert.NotPanics(t, func() { - InitController() - }) -} - -func TestSlaveController_HandleHeartBeat(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: make(map[string]MasterInfo), - } - - // first heart beat - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - Node: &model.Node{}, - }) - a.NoError(err) - - _, err = c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "2", - Node: &model.Node{}, - }) - a.NoError(err) - - a.Len(c.masters, 2) - } - - // second heart beat, no fresh - { - _, err := 
c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - SiteURL: "http://127.0.0.1", - Node: &model.Node{}, - }) - a.NoError(err) - a.Len(c.masters, 2) - a.Empty(c.masters["1"].URL) - } - - // second heart beat, fresh - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - IsUpdate: true, - SiteURL: "http://127.0.0.1", - Node: &model.Node{}, - }) - a.NoError(err) - a.Len(c.masters, 2) - a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) - } - - // second heart beat, fresh, url illegal - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - IsUpdate: true, - SiteURL: string([]byte{0x7f}), - Node: &model.Node{}, - }) - a.Error(err) - a.Len(c.masters, 2) - a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) - } -} - -type nodeMock struct { - testMock.Mock -} - -func (n nodeMock) Init(node *model.Node) { - n.Called(node) -} - -func (n nodeMock) IsFeatureEnabled(feature string) bool { - args := n.Called(feature) - return args.Bool(0) -} - -func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) { - n.Called(callback) -} - -func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { - args := n.Called(req) - return args.Get(0).(*serializer.NodePingResp), args.Error(1) -} - -func (n nodeMock) IsActive() bool { - args := n.Called() - return args.Bool(0) -} - -func (n nodeMock) GetAria2Instance() common.Aria2 { - args := n.Called() - return args.Get(0).(common.Aria2) -} - -func (n nodeMock) ID() uint { - args := n.Called() - return args.Get(0).(uint) -} - -func (n nodeMock) Kill() { - n.Called() -} - -func (n nodeMock) IsMater() bool { - args := n.Called() - return args.Bool(0) -} - -func (n nodeMock) MasterAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n nodeMock) SlaveAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n nodeMock) DBModel() *model.Node { - args := n.Called() - return args.Get(0).(*model.Node) -} - -func TestSlaveController_GetAria2Instance(t *testing.T) { - a := assert.New(t) - mockNode := &nodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Instance: mockNode}, - }, - } - - // node node found - { - res, err := c.GetAria2Instance("2") - a.Nil(res) - a.Equal(ErrMasterNotFound, err) - } - - // node found - { - res, err := c.GetAria2Instance("1") - a.NotNil(res) - a.NoError(err) - mockNode.AssertExpectations(t) - } - -} - -type requestMock struct { - testMock.Mock -} - -func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - return r.Called(method, target, body, opts).Get(0).(*request.Response) -} - -func TestSlaveController_SendNotification(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{})) - } - - // gob encode error - { - type randomType struct{} - a.Error(c.SendNotification("1", "", mq.Message{ - Content: randomType{}, - })) - } - - // return none 200 - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{StatusCode: http.StatusConflict}, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - 
a.Error(c.SendNotification("1", "s1", mq.Message{})) - mockRequest.AssertExpectations(t) - } - - // master return error - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code) - mockRequest.AssertExpectations(t) - } - - // success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - a.NoError(c.SendNotification("1", "s3", mq.Message{})) - mockRequest.AssertExpectations(t) - } -} - -func TestSlaveController_SubmitTask(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": { - jobTracker: map[string]bool{}, - }, - }, - } - - // node not exit - { - a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil)) - } - - // success - { - submitted := false - a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { - submitted = true - })) - a.True(submitted) - } - - // job already submitted - { - submitted := false - a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { - submitted = true - })) - a.False(submitted) - } -} - -func TestSlaveController_GetMasterInfo(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - res, err := c.GetMasterInfo("2") - a.Equal(ErrMasterNotFound, err) - a.Nil(res) - } - - // success - { - res, err := c.GetMasterInfo("1") - a.NoError(err) - a.NotNil(res) - } -} - -func TestSlaveController_GetOneDriveToken(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - res, err := c.GetPolicyOauthToken("2", 1) - a.Equal(ErrMasterNotFound, err) - a.Empty(res) - } - - // return none 200 - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{StatusCode: http.StatusConflict}, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.Error(err) - a.Empty(res) - mockRequest.AssertExpectations(t) - } - - // master return error - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.Equal(1, err.(serializer.AppError).Code) - a.Empty(res) - mockRequest.AssertExpectations(t) - } - - // success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, 
testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.NoError(err) - a.Equal("expected", res) - mockRequest.AssertExpectations(t) - } - -} diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go deleted file mode 100644 index acd21d33..00000000 --- a/pkg/cluster/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package cluster - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -var ( - ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") - ErrIlegalPath = errors.New("path out of boundary of setting temp folder") - ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil) -) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go deleted file mode 100644 index 9c3dc619..00000000 --- a/pkg/cluster/master.go +++ /dev/null @@ -1,272 +0,0 @@ -package cluster - -import ( - "context" - "encoding/json" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" -) - -const ( - deleteTempFileDuration = 60 * time.Second - statusRetryDuration = 10 * time.Second -) - -type MasterNode struct { - Model *model.Node - aria2RPC rpcService - lock sync.RWMutex -} - -// RPCService 通过RPC服务的Aria2任务管理器 -type rpcService struct { - Caller rpc.Client - Initialized bool - - retryDuration time.Duration - deletePaddingDuration time.Duration - parent *MasterNode - options *clientOptions -} - -type clientOptions struct { - Options map[string]interface{} // 创建下载时额外添加的设置 -} - -// Init 初始化节点 -func (node *MasterNode) Init(nodeModel *model.Node) { - node.lock.Lock() - node.Model = nodeModel - node.aria2RPC.parent = node - node.aria2RPC.retryDuration = statusRetryDuration - node.aria2RPC.deletePaddingDuration = deleteTempFileDuration - node.lock.Unlock() - - node.lock.RLock() - if node.Model.Aria2Enabled { - node.lock.RUnlock() - node.aria2RPC.Init() - return - } - node.lock.RUnlock() -} - -func (node *MasterNode) ID() uint { - node.lock.RLock() - defer node.lock.RUnlock() - - return node.Model.ID -} - -func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { - return &serializer.NodePingResp{}, nil -} - -// IsFeatureEnabled 查询节点的某项功能是否启用 -func (node *MasterNode) IsFeatureEnabled(feature string) bool { - node.lock.RLock() - defer node.lock.RUnlock() - - switch feature { - case "aria2": - return node.Model.Aria2Enabled - default: - return false - } -} - -func (node *MasterNode) MasterAuthInstance() auth.Auth { - node.lock.RLock() - defer node.lock.RUnlock() - - return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} -} - -func (node *MasterNode) SlaveAuthInstance() auth.Auth { - node.lock.RLock() - defer node.lock.RUnlock() - - return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} -} - -// SubscribeStatusChange 订阅节点状态更改 -func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) { -} - -// IsActive 
返回节点是否在线 -func (node *MasterNode) IsActive() bool { - return true -} - -// Kill 结束aria2请求 -func (node *MasterNode) Kill() { - if node.aria2RPC.Caller != nil { - node.aria2RPC.Caller.Close() - } -} - -// GetAria2Instance 获取主机Aria2实例 -func (node *MasterNode) GetAria2Instance() common.Aria2 { - node.lock.RLock() - - if !node.Model.Aria2Enabled { - node.lock.RUnlock() - return &common.DummyAria2{} - } - - if !node.aria2RPC.Initialized { - node.lock.RUnlock() - node.aria2RPC.Init() - return &common.DummyAria2{} - } - - defer node.lock.RUnlock() - return &node.aria2RPC -} - -func (node *MasterNode) IsMater() bool { - return true -} - -func (node *MasterNode) DBModel() *model.Node { - node.lock.RLock() - defer node.lock.RUnlock() - - return node.Model -} - -func (r *rpcService) Init() error { - r.parent.lock.Lock() - defer r.parent.lock.Unlock() - r.Initialized = false - - // 客户端已存在,则关闭先前连接 - if r.Caller != nil { - r.Caller.Close() - } - - // 解析RPC服务地址 - server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server) - if err != nil { - util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err) - return err - } - server.Path = "/jsonrpc" - - // 加载自定义下载配置 - var globalOptions map[string]interface{} - if r.parent.Model.Aria2OptionsSerialized.Options != "" { - err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) - if err != nil { - util.Log().Warning("Failed to parse aria2 options: %s", err) - return err - } - } - - r.options = &clientOptions{ - Options: globalOptions, - } - timeout := r.parent.Model.Aria2OptionsSerialized.Timeout - caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ) - - r.Caller = caller - r.Initialized = err == nil - return err -} - -func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { - r.parent.lock.RLock() - // 生成存储路径 - guid, _ := uuid.NewV4() - path := filepath.Join( - r.parent.Model.Aria2OptionsSerialized.TempPath, - "aria2", - guid.String(), - ) - r.parent.lock.RUnlock() - - // 创建下载任务 - options := map[string]interface{}{ - "dir": path, - } - for k, v := range r.options.Options { - options[k] = v - } - for k, v := range groupOptions { - options[k] = v - } - - gid, err := r.Caller.AddURI(task.Source, options) - if err != nil || gid == "" { - return "", err - } - - return gid, nil -} - -func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { - res, err := r.Caller.TellStatus(task.GID) - if err != nil { - // 失败后重试 - util.Log().Debug("Failed to get download task status, please retry later: %s", err) - time.Sleep(r.retryDuration) - res, err = r.Caller.TellStatus(task.GID) - } - - return res, err -} - -func (r *rpcService) Cancel(task *model.Download) error { - // 取消下载任务 - _, err := r.Caller.Remove(task.GID) - if err != nil { - util.Log().Warning("Failed to cancel task %q: %s", task.GID, err) - } - - return err -} - -func (r *rpcService) Select(task *model.Download, files []int) error { - var selected = make([]string, len(files)) - for i := 0; i < len(files); i++ { - selected[i] = strconv.Itoa(files[i]) - } - _, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) - return err -} - -func (r *rpcService) GetConfig() model.Aria2Option { - r.parent.lock.RLock() - defer r.parent.lock.RUnlock() - - return r.parent.Model.Aria2OptionsSerialized -} - -func (s *rpcService) DeleteTempFile(task 
*model.Download) error { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - // 避免被aria2占用,异步执行删除 - go func(d time.Duration, src string) { - time.Sleep(d) - err := os.RemoveAll(src) - if err != nil { - util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err) - } - }(s.deletePaddingDuration, task.Parent) - - return nil -} diff --git a/pkg/cluster/master_test.go b/pkg/cluster/master_test.go deleted file mode 100644 index 7ff07ac9..00000000 --- a/pkg/cluster/master_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package cluster - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" - "os" - "testing" - "time" -) - -func TestMasterNode_Init(t *testing.T) { - a := assert.New(t) - m := &MasterNode{} - m.Init(&model.Node{Status: model.NodeSuspend}) - a.Equal(model.NodeSuspend, m.DBModel().Status) - m.Init(&model.Node{Aria2Enabled: true}) -} - -func TestMasterNode_DummyMethods(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - m.Model.ID = 5 - a.Equal(m.Model.ID, m.ID()) - - res, err := m.Ping(&serializer.NodePingReq{}) - a.NoError(err) - a.NotNil(res) - - a.True(m.IsActive()) - a.True(m.IsMater()) - - m.SubscribeStatusChange(func(isActive bool, id uint) {}) -} - -func TestMasterNode_IsFeatureEnabled(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - a.False(m.IsFeatureEnabled("aria2")) - a.False(m.IsFeatureEnabled("random")) - m.Model.Aria2Enabled = true - a.True(m.IsFeatureEnabled("aria2")) -} - -func TestMasterNode_AuthInstance(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - a.NotNil(m.MasterAuthInstance()) - a.NotNil(m.SlaveAuthInstance()) -} - -func TestMasterNode_Kill(t *testing.T) { - m := &MasterNode{ - Model: &model.Node{}, - } - - m.Kill() - - caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - m.Kill() -} - -func TestMasterNode_GetAria2Instance(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - aria2RPC: rpcService{}, - } - - m.aria2RPC.parent = m - - a.NotNil(m.GetAria2Instance()) - m.Model.Aria2Enabled = true - a.NotNil(m.GetAria2Instance()) - m.aria2RPC.Initialized = true - a.NotNil(m.GetAria2Instance()) -} - -func TestRpcService_Init(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{ - Aria2OptionsSerialized: model.Aria2Option{ - Options: "{", - }, - }, - aria2RPC: rpcService{}, - } - m.aria2RPC.parent = m - - // failed to decode address - { - m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f}) - a.Error(m.aria2RPC.Init()) - } - - // failed to decode options - { - m.Model.Aria2OptionsSerialized.Server = "" - a.Error(m.aria2RPC.Init()) - } - - // failed to initialized - { - m.Model.Aria2OptionsSerialized.Server = "" - m.Model.Aria2OptionsSerialized.Options = "{}" - caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - a.Error(m.aria2RPC.Init()) - a.False(m.aria2RPC.Initialized) - } -} - -func getTestRPCNode() *MasterNode { - m := &MasterNode{ - Model: &model.Node{ - Aria2OptionsSerialized: model.Aria2Option{}, - }, - aria2RPC: rpcService{ - options: &clientOptions{ - Options: map[string]interface{}{"1": "1"}, - }, - }, - } - m.aria2RPC.parent = m - caller, _ := 
rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - return m -} - -func TestRpcService_CreateTask(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"}) - a.Error(err) - a.Empty(res) -} - -func TestRpcService_Status(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - res, err := m.aria2RPC.Status(&model.Download{}) - a.Error(err) - a.Empty(res) -} - -func TestRpcService_Cancel(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - a.Error(m.aria2RPC.Cancel(&model.Download{})) -} - -func TestRpcService_Select(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - a.NotNil(m.aria2RPC.GetConfig()) - a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3})) -} - -func TestRpcService_DeleteTempFile(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - fdName := "TestRpcService_DeleteTempFile" - a.NoError(os.Mkdir(fdName, 0644)) - - a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName})) - time.Sleep(500 * time.Millisecond) - a.False(util.Exists(fdName)) -} diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go index 745dd259..584e85c7 100644 --- a/pkg/cluster/node.go +++ b/pkg/cluster/node.go @@ -1,60 +1,413 @@ package cluster import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/qbittorrent" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "strconv" ) -type Node interface { - // Init a node from database model - Init(node *model.Node) +type ( + Node interface { + fs.StatelessUploadManager + ID() int + Name() string + IsMaster() bool + // CreateTask creates a task on the node. It does not have effect on master node. + CreateTask(ctx context.Context, taskType string, state string) (int, error) + // GetTask returns the task summary of the task with the given id. + GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) + // CleanupFolders cleans up the given folders on the node. + CleanupFolders(ctx context.Context, folders ...string) error + // AuthInstance returns the auth instance for the node. + AuthInstance() auth.Auth + // CreateDownloader creates a downloader instance from the node for remote download tasks. + CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) + // Settings returns the settings of the node. 
+ Settings(ctx context.Context) *types.NodeSetting + } + + // Request body for creating tasks on slave node + CreateSlaveTask struct { + Type string `json:"type"` + State string `json:"state"` + } + + // Request body for cleaning up folders on slave node + FolderCleanup struct { + Path []string `json:"path" binding:"required"` + } + + SlaveTaskSummary struct { + Status task.Status `json:"status"` + Error string `json:"error"` + PrivateState string `json:"private_state"` + Progress queue.Progresses `json:"progress,omitempty"` + } + + MasterSiteUrlCtx struct{} + MasterSiteVersionCtx struct{} + MasterSiteIDCtx struct{} + SlaveNodeIDCtx struct{} + masterNode struct { + nodeBase + client request.Client + } +) + +func newNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) Node { + if model.Type == node.TypeMaster { + return newMasterNode(model, config, settings) + } + return newSlaveNode(ctx, model, config, settings) +} + +func newMasterNode(model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *masterNode { + n := &masterNode{ + nodeBase: nodeBase{ + model: model, + }, + } + + if config.System().Mode == conf.SlaveMode { + n.client = request.NewClient(config, + request.WithCorrelationID(), + request.WithCredential(auth.HMACAuth{ + []byte(config.Slave().Secret), + }, int64(config.Slave().SignatureTTL)), + ) + } + + return n +} + +func (b *masterNode) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) { + reqBody, err := json.Marshal(args) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "prepare") + resp, err := b.client.Request( + "PUT", + requestDst.String(), + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithSlaveMeta(NodeIdFromContext(ctx)), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + // 处理列取结果 + if resp.Code != 0 { + return nil, serializer.NewErrorFromResponse(resp) + } + + uploadRequest := &fs.StatelessPrepareUploadResponse{} + resp.GobDecode(uploadRequest) + + return uploadRequest, nil +} + +func (b *masterNode) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error { + reqBody, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "complete") + resp, err := b.client.Request( + "POST", + requestDst.String(), + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithSlaveMeta(NodeIdFromContext(ctx)), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (b *masterNode) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error { + reqBody, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "failed") + resp, err := b.client.Request( + "POST", + requestDst.String(), + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithSlaveMeta(NodeIdFromContext(ctx)), + request.WithLogger(logging.FromContext(ctx)), + 
).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (b *masterNode) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error { + reqBody, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "create") + resp, err := b.client.Request( + "POST", + requestDst.String(), + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithSlaveMeta(NodeIdFromContext(ctx)), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil + +} + +func (b *masterNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) { + return NewDownloader(ctx, c, settings, b.Settings(ctx)) +} + +// NewDownloader creates a new downloader instance from the node for remote download tasks. +func NewDownloader(ctx context.Context, c request.Client, settings setting.Provider, options *types.NodeSetting) (downloader.Downloader, error) { + if options.Provider == types.DownloaderProviderQBittorrent { + return qbittorrent.NewClient(logging.FromContext(ctx), c, settings, options.QBittorrentSetting) + } else if options.Provider == types.DownloaderProviderAria2 { + return aria2.New(logging.FromContext(ctx), settings, options.Aria2Setting), nil + } else if options.Provider == "" { + return nil, errors.New("downloader not configured for this node") + } else { + return nil, errors.New("unknown downloader provider") + } +} + +type slaveNode struct { + nodeBase + client request.Client +} + +func newSlaveNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *slaveNode { + siteBasic := settings.SiteBasic(ctx) + return &slaveNode{ + nodeBase: nodeBase{ + model: model, + }, + client: request.NewClient(config, + request.WithCorrelationID(), + request.WithSlaveMeta(model.ID), + request.WithMasterMeta(siteBasic.ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()), + request.WithCredential(auth.HMACAuth{[]byte(model.SlaveKey)}, int64(settings.SlaveRequestSignTTL(ctx))), + request.WithEndpoint(model.Server)), + } +} + +func (n *slaveNode) CreateTask(ctx context.Context, taskType string, state string) (int, error) { + reqBody, err := json.Marshal(&CreateSlaveTask{ + Type: taskType, + State: state, + }) + if err != nil { + return 0, fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := n.client.Request( + "PUT", + constants.APIPrefixSlave+"/task", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return 0, err + } + + // 处理列取结果 + if resp.Code != 0 { + return 0, serializer.NewErrorFromResponse(resp) + } + + taskId := 0 + if resp.GobDecode(&taskId); taskId > 0 { + return taskId, nil + } + + return 0, fmt.Errorf("unexpected response data: %v", resp.Data) +} + +func (n *slaveNode) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) { + resp, err := n.client.Request( + "GET", + routes.SlaveGetTaskRoute(id, clearOnComplete), + nil, + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + 
).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + // 处理列取结果 + if resp.Code != 0 { + return nil, serializer.NewErrorFromResponse(resp) + } + + summary := &SlaveTaskSummary{} + resp.GobDecode(summary) + + return summary, nil +} + +func (b *slaveNode) CleanupFolders(ctx context.Context, folders ...string) error { + args := &FolderCleanup{ + Path: folders, + } + reqBody, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := b.client.Request( + "POST", + constants.APIPrefixSlave+"/task/cleanup", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (b *slaveNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) { + return slave.NewSlaveDownloader(b.client, b.Settings(ctx)), nil +} + +type nodeBase struct { + model *ent.Node +} + +func (b *nodeBase) ID() int { + return b.model.ID +} + +func (b *nodeBase) Name() string { + return b.model.Name +} + +func (b *nodeBase) IsMaster() bool { + return b.model.Type == node.TypeMaster +} - // Check if given feature is enabled - IsFeatureEnabled(feature string) bool +func (b *nodeBase) CreateTask(ctx context.Context, taskType string, state string) (int, error) { + return 0, errors.New("not implemented") +} - // Subscribe node status change to a callback function - SubscribeStatusChange(callback func(isActive bool, id uint)) +func (b *nodeBase) AuthInstance() auth.Auth { + return auth.HMACAuth{[]byte(b.model.SlaveKey)} +} - // Ping the node - Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) +func (b *nodeBase) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) { + return nil, errors.New("not implemented") +} - // Returns if the node is active - IsActive() bool +func (b *nodeBase) CleanupFolders(ctx context.Context, folders ...string) error { + return errors.New("not implemented") +} + +func (b *nodeBase) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) { + return nil, errors.New("not implemented") +} - // Get instances for aria2 calls - GetAria2Instance() common.Aria2 +func (b *nodeBase) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error { + return errors.New("not implemented") +} - // Returns unique id of this node - ID() uint +func (b *nodeBase) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error { + return errors.New("not implemented") +} - // Kill node and recycle resources - Kill() +func (b *nodeBase) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error { + return errors.New("not implemented") +} - // Returns if current node is master node - IsMater() bool +func (b *nodeBase) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) { + return nil, errors.New("not implemented") +} - // Get auth instance used to check RPC call from slave to master - MasterAuthInstance() auth.Auth +func (b *nodeBase) Settings(ctx context.Context) *types.NodeSetting { + return b.model.Settings +} - // Get auth instance used to check RPC call from master to slave - SlaveAuthInstance() auth.Auth +func 
NodeIdFromContext(ctx context.Context) int { + nodeIdStr, ok := ctx.Value(SlaveNodeIDCtx{}).(string) + if !ok { + return 0 + } - // Get node DB model - DBModel() *model.Node + nodeId, _ := strconv.Atoi(nodeIdStr) + return nodeId } -// Create new node from DB model -func NewNodeFromDBModel(node *model.Node) Node { - switch node.Type { - case model.SlaveNodeType: - slave := &SlaveNode{} - slave.Init(node) - return slave - default: - master := &MasterNode{} - master.Init(node) - return master +func MasterSiteUrlFromContext(ctx context.Context) string { + if u, ok := ctx.Value(MasterSiteUrlCtx{}).(string); ok { + return u } + + return "" } diff --git a/pkg/cluster/node_test.go b/pkg/cluster/node_test.go deleted file mode 100644 index d817425a..00000000 --- a/pkg/cluster/node_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package cluster - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewNodeFromDBModel(t *testing.T) { - a := assert.New(t) - a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{ - Type: model.SlaveNodeType, - })) - a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{ - Type: model.MasterNodeType, - })) -} diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index d6704b60..dfbc75de 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -1,190 +1,203 @@ package cluster import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "context" + "fmt" "sync" -) - -var Default *NodePool - -// 需要分类的节点组 -var featureGroup = []string{"aria2"} -// Pool 节点池 -type Pool interface { - // Returns active node selected by given feature and load balancer - BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) - - // Returns node by ID - GetNodeByID(id uint) Node - - // Add given node into pool. If node existed, refresh node. - Add(node *model.Node) + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" +) - // Delete and kill node from pool by given node id - Delete(id uint) +type NodePool interface { + // Upsert updates or inserts a node into the pool. + Upsert(ctx context.Context, node *ent.Node) + // Get returns a node with the given capability and preferred node id. `allowed` is a list of allowed node ids. + // If `allowed` is empty, all nodes with the capability are considered. 
+ Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) } -// NodePool 通用节点池 -type NodePool struct { - active map[uint]Node - inactive map[uint]Node +type ( + weightedNodePool struct { + lock sync.RWMutex - featureMap map[string][]Node + conf conf.ConfigProvider + settings setting.Provider - lock sync.RWMutex -} + nodes map[types.NodeCapability][]*nodeItem + } -// Init 初始化从机节点池 -func Init() { - Default = &NodePool{} - Default.Init() - if err := Default.initFromDB(); err != nil { - util.Log().Warning("Failed to initialize node pool: %s", err) + nodeItem struct { + node Node + weight int + current int } -} +) -func (pool *NodePool) Init() { - pool.lock.Lock() - defer pool.lock.Unlock() +var ( + ErrNoAvailableNode = fmt.Errorf("no available node found") - pool.featureMap = make(map[string][]Node) - pool.active = make(map[uint]Node) - pool.inactive = make(map[uint]Node) -} + supportedCapabilities = []types.NodeCapability{ + types.NodeCapabilityNone, + types.NodeCapabilityCreateArchive, + types.NodeCapabilityExtractArchive, + types.NodeCapabilityRemoteDownload, + } +) -func (pool *NodePool) buildIndexMap() { - pool.lock.Lock() - for _, feature := range featureGroup { - pool.featureMap[feature] = make([]Node, 0) +func NewNodePool(ctx context.Context, l logging.Logger, config conf.ConfigProvider, settings setting.Provider, + client inventory.NodeClient) (NodePool, error) { + nodes, err := client.ListActiveNodes(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to list active nodes: %w", err) } - for _, v := range pool.active { - for _, feature := range featureGroup { - if v.IsFeatureEnabled(feature) { - pool.featureMap[feature] = append(pool.featureMap[feature], v) + pool := &weightedNodePool{ + nodes: make(map[types.NodeCapability][]*nodeItem), + conf: config, + settings: settings, + } + for _, node := range nodes { + for _, capability := range supportedCapabilities { + // If current capability is enabled, add it to pool slot. + if capability == types.NodeCapabilityNone || + (node.Capabilities != nil && node.Capabilities.Enabled(int(capability))) { + if _, ok := pool.nodes[capability]; !ok { + pool.nodes[capability] = make([]*nodeItem, 0) + } + + l.Debug("Add node %q to capability slot %d with weight %d", node.Name, capability, node.Weight) + pool.nodes[capability] = append(pool.nodes[capability], &nodeItem{ + node: newNode(ctx, node, config, settings), + weight: node.Weight, + current: 0, + }) } } } - pool.lock.Unlock() + + return pool, nil } -func (pool *NodePool) GetNodeByID(id uint) Node { - pool.lock.RLock() - defer pool.lock.RUnlock() +func (p *weightedNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) { + l := logging.FromContext(ctx) + p.lock.RLock() + defer p.lock.RUnlock() - if node, ok := pool.active[id]; ok { - return node + nodes, ok := p.nodes[capability] + if !ok || len(nodes) == 0 { + return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode) } - return pool.inactive[id] -} + var selected *nodeItem -func (pool *NodePool) nodeStatusChange(isActive bool, id uint) { - util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive) - var node Node - pool.lock.Lock() - if n, ok := pool.inactive[id]; ok { - node = n - delete(pool.inactive, id) - } else { - node = pool.active[id] - delete(pool.active, id) - } + if preferred > 0 { + // First try to find the preferred node. 
+ for _, n := range nodes { + if n.node.ID() == preferred { + selected = n + break + } + } - if isActive { - pool.active[id] = node - } else { - pool.inactive[id] = node + if selected == nil { + l.Debug("Preferred node %d not found, fallback to select a node with the least current weight", preferred) + } } - pool.lock.Unlock() - pool.buildIndexMap() -} + if selected == nil { + // If no preferred one, or the preferred one is not available, select a node with the least current weight. -func (pool *NodePool) initFromDB() error { - nodes, err := model.GetNodesByStatus(model.NodeActive) - if err != nil { - return err - } + // Total weight of all items. + var total int - pool.lock.Lock() - for i := 0; i < len(nodes); i++ { - pool.add(&nodes[i]) - } - pool.lock.Unlock() + // Loop through the list of items and add the item's weight to the current weight. + // Also increment the total weight counter. + var maxNode *nodeItem + for _, item := range nodes { + item.current += max(1, item.weight) + total += max(1, item.weight) - pool.buildIndexMap() - return nil -} + // Select the item with max weight. + if maxNode == nil || item.current > maxNode.current { + maxNode = item + } + } -func (pool *NodePool) add(node *model.Node) { - newNode := NewNodeFromDBModel(node) - if newNode.IsActive() { - pool.active[node.ID] = newNode - } else { - pool.inactive[node.ID] = newNode - } + // Select the item with the max weight. + selected = maxNode + if selected == nil { + return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode) + } - // 订阅节点状态变更 - newNode.SubscribeStatusChange(func(isActive bool, id uint) { - pool.nodeStatusChange(isActive, id) - }) -} + l.Debug("Selected node %q with weight=%d, current=%d, total=%d", selected.node.Name(), selected.weight, maxNode.current, total) -func (pool *NodePool) Add(node *model.Node) { - pool.lock.Lock() - defer pool.buildIndexMap() - defer pool.lock.Unlock() - - var ( - old Node - ok bool - ) - if old, ok = pool.active[node.ID]; !ok { - old, ok = pool.inactive[node.ID] - } - if old != nil { - go old.Init(node) - return + // Reduce the current weight of the selected item by the total weight. + maxNode.current -= total } - pool.add(node) + return selected.node, nil } -func (pool *NodePool) Delete(id uint) { - pool.lock.Lock() - defer pool.buildIndexMap() - defer pool.lock.Unlock() - - if node, ok := pool.active[id]; ok { - node.Kill() - delete(pool.active, id) - return - } +func (p *weightedNodePool) Upsert(ctx context.Context, n *ent.Node) { + p.lock.Lock() + defer p.lock.Unlock() + + for _, capability := range supportedCapabilities { + _, index, found := lo.FindIndexOf(p.nodes[capability], func(i *nodeItem) bool { + return i.node.ID() == n.ID + }) + if capability == types.NodeCapabilityNone || + (n.Capabilities != nil && n.Capabilities.Enabled(int(capability))) { + if n.Status != node.StatusActive && found { + // Remove inactive node + p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...) + continue + } - if node, ok := pool.inactive[id]; ok { - node.Kill() - delete(pool.inactive, id) - return + if found { + p.nodes[capability][index].node = newNode(ctx, n, p.conf, p.settings) + } else { + p.nodes[capability] = append(p.nodes[capability], &nodeItem{ + node: newNode(ctx, n, p.conf, p.settings), + weight: n.Weight, + current: 0, + }) + } + } else if found { + // Capability changed, remove the old node. + p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...) 
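The selection loop in weightedNodePool.Get above is a smooth weighted round-robin. The standalone sketch below shows the same bookkeeping in isolation (the item/pick names are illustrative and not part of the codebase): each candidate gains its weight (at least 1) on every round, the highest running total wins, and the winner then pays back the sum of all weights so higher-weight nodes are picked more often without starving the rest.

package main

import "fmt"

// item mirrors the fields the selection uses: a static weight and a running total.
type item struct {
	name            string
	weight, current int
}

// pick selects one item per call using smooth weighted round-robin.
func pick(items []*item) *item {
	total := 0
	var selected *item
	for _, it := range items {
		w := it.weight
		if w < 1 {
			w = 1
		}
		it.current += w
		total += w
		if selected == nil || it.current > selected.current {
			selected = it
		}
	}
	// The winner pays back the total, so others catch up on later rounds.
	selected.current -= total
	return selected
}

func main() {
	nodes := []*item{{name: "a", weight: 3}, {name: "b", weight: 1}}
	for i := 0; i < 4; i++ {
		fmt.Print(pick(nodes).name, " ") // prints: a a b a
	}
	fmt.Println()
}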
+ } } - } -// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点 -func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) { - pool.lock.RLock() - defer pool.lock.RUnlock() - if nodes, ok := pool.featureMap[feature]; ok { - err, res := lb.NextPeer(nodes) - if err == nil { - return nil, res.(Node) - } +type slaveDummyNodePool struct { + conf conf.ConfigProvider + settings setting.Provider + masterNode Node +} - return err, nil +func NewSlaveDummyNodePool(ctx context.Context, config conf.ConfigProvider, settings setting.Provider) NodePool { + return &slaveDummyNodePool{ + conf: config, + settings: settings, + masterNode: newNode(ctx, &ent.Node{ + ID: 0, + Name: "Master", + Type: node.TypeMaster, + }, config, settings), } +} + +func (s *slaveDummyNodePool) Upsert(ctx context.Context, node *ent.Node) { +} - return ErrFeatureNotExist, nil +func (s *slaveDummyNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) { + return s.masterNode, nil } diff --git a/pkg/cluster/pool_test.go b/pkg/cluster/pool_test.go deleted file mode 100644 index dde34558..00000000 --- a/pkg/cluster/pool_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package cluster - -import ( - "database/sql" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInitFailed(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - Init() - a.NoError(mock.ExpectationsWereMet()) -} - -func TestInitSuccess(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType)) - Init() - a.NoError(mock.ExpectationsWereMet()) -} - -func TestNodePool_GetNodeByID(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - mockNode := &nodeMock{} - - // inactive - { - p.inactive[1] = mockNode - a.Equal(mockNode, p.GetNodeByID(1)) - } - - // active - { - delete(p.inactive, 1) - p.active[1] = mockNode - a.Equal(mockNode, p.GetNodeByID(1)) - } -} - -func TestNodePool_NodeStatusChange(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - n := &MasterNode{Model: &model.Node{}} - p.Init() - p.inactive[1] = n - - p.nodeStatusChange(true, 1) - a.Len(p.inactive, 0) - a.Equal(n, p.active[1]) - - p.nodeStatusChange(false, 1) - a.Len(p.active, 0) - a.Equal(n, p.inactive[1]) - - p.nodeStatusChange(false, 1) - a.Len(p.active, 0) - a.Equal(n, p.inactive[1]) -} - -func TestNodePool_Add(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // new node - { - p.Add(&model.Node{}) - a.Len(p.active, 1) - } - - // old node - { - p.inactive[0] = p.active[0] - delete(p.active, 0) - p.Add(&model.Node{}) - a.Len(p.active, 0) - a.Len(p.inactive, 1) - } -} - -func TestNodePool_Delete(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // active - { - mockNode := &nodeMock{} - mockNode.On("Kill") - p.active[0] = mockNode - p.Delete(0) - a.Len(p.active, 0) - a.Len(p.inactive, 0) - 
mockNode.AssertExpectations(t) - } - - p.Init() - - // inactive - { - mockNode := &nodeMock{} - mockNode.On("Kill") - p.inactive[0] = mockNode - p.Delete(0) - a.Len(p.active, 0) - a.Len(p.inactive, 0) - mockNode.AssertExpectations(t) - } -} - -func TestNodePool_BalanceNodeByFeature(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // success - { - p.featureMap["test"] = []Node{&MasterNode{}} - err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) - a.NoError(err) - a.Equal(p.featureMap["test"][0], res) - } - - // NoNodes - { - p.featureMap["test"] = []Node{} - err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) - a.Error(err) - a.Nil(res) - } - - // No match feature - { - err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin")) - a.Error(err) - a.Nil(res) - } -} diff --git a/pkg/cluster/routes/routes.go b/pkg/cluster/routes/routes.go new file mode 100644 index 00000000..d4bdf0a8 --- /dev/null +++ b/pkg/cluster/routes/routes.go @@ -0,0 +1,207 @@ +package routes + +import ( + "encoding/base64" + "fmt" + "net/url" + "path" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" +) + +const ( + IsDownloadQuery = "download" + IsThumbQuery = "thumb" + SlaveClearTaskRegistryQuery = "deleteOnComplete" +) + +var ( + masterPing *url.URL + masterUserActivate *url.URL + masterUserReset *url.URL + masterHome *url.URL +) + +func init() { + masterPing, _ = url.Parse(constants.APIPrefix + "/site/ping") + masterUserActivate, _ = url.Parse("/session/activate") + masterUserReset, _ = url.Parse("/session/reset") +} + +func FrontendHomeUrl(base *url.URL, path string) *url.URL { + route, _ := url.Parse(fmt.Sprintf("/home")) + q := route.Query() + q.Set("path", path) + route.RawQuery = q.Encode() + + return base.ResolveReference(route) +} + +func MasterPingUrl(base *url.URL) *url.URL { + return base.ResolveReference(masterPing) +} + +func MasterSlaveCallbackUrl(base *url.URL, driver, id, secret string) *url.URL { + apiBaseURI, _ := url.Parse(path.Join(constants.APIPrefix+"/callback", driver, id, secret)) + return base.ResolveReference(apiBaseURI) +} + +func MasterUserActivateAPIUrl(base *url.URL, uid string) *url.URL { + route, _ := url.Parse(constants.APIPrefix + "/user/activate/" + uid) + return base.ResolveReference(route) +} + +func MasterUserActivateUrl(base *url.URL) *url.URL { + return base.ResolveReference(masterUserActivate) +} + +func MasterUserResetUrl(base *url.URL) *url.URL { + return base.ResolveReference(masterUserReset) +} + +func MasterShareUrl(base *url.URL, id, password string) *url.URL { + p := "/s/" + id + if password != "" { + p += ("/" + password) + } + route, _ := url.Parse(p) + return base.ResolveReference(route) +} + +func MasterDirectLink(base *url.URL, id, name string) *url.URL { + p := path.Join("/f", id, url.PathEscape(name)) + route, _ := url.Parse(p) + return base.ResolveReference(route) +} + +// MasterShareLongUrl generates a long share URL for redirect. 
+func MasterShareLongUrl(id, password string) *url.URL { + base, _ := url.Parse("/home") + q := base.Query() + + q.Set("path", fs.NewShareUri(id, password)) + base.RawQuery = q.Encode() + return base +} + +func MasterArchiveDownloadUrl(base *url.URL, sessionID string) *url.URL { + routes, err := url.Parse(path.Join(constants.APIPrefix, "file", "archive", sessionID, "archive.zip")) + if err != nil { + return nil + } + + return base.ResolveReference(routes) +} + +func MasterPolicyOAuthCallback(base *url.URL) *url.URL { + if base.Scheme != "https" { + base.Scheme = "https" + } + routes, err := url.Parse("/admin/policy/oauth") + if err != nil { + return nil + } + return base.ResolveReference(routes) +} + +func MasterGetCredentialUrl(base, key string) *url.URL { + masterBase, err := url.Parse(base) + if err != nil { + return nil + } + + routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "credential", key)) + if err != nil { + return nil + } + + return masterBase.ResolveReference(routes) +} + +func MasterStatelessUrl(base, method string) *url.URL { + masterBase, err := url.Parse(base) + if err != nil { + return nil + } + + routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "statelessUpload", method)) + if err != nil { + return nil + } + + return masterBase.ResolveReference(routes) +} + +func SlaveUploadUrl(base *url.URL, sessionID string) *url.URL { + base.Path = path.Join(base.Path, constants.APIPrefixSlave, "/upload", sessionID) + return base +} + +func MasterFileContentUrl(base *url.URL, entityId, name string, download, thumb bool, speed int64) *url.URL { + name = url.PathEscape(name) + + route, _ := url.Parse(constants.APIPrefix + fmt.Sprintf("/file/content/%s/%d/%s", entityId, speed, name)) + if base != nil { + route = base.ResolveReference(route) + } + + values := url.Values{} + if download { + values.Set(IsDownloadQuery, "true") + } + + if thumb { + values.Set(IsThumbQuery, "true") + } + + route.RawQuery = values.Encode() + return route +} + +func MasterWopiSrc(base *url.URL, sessionId string) *url.URL { + route, _ := url.Parse(constants.APIPrefix + "/file/wopi/" + sessionId) + return base.ResolveReference(route) +} + +func SlaveFileContentUrl(base *url.URL, srcPath, name string, download bool, speed int64, nodeId int) *url.URL { + srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath))) + name = url.PathEscape(name) + route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/content/%d/%s/%d/%s", nodeId, srcPath, speed, name)) + base = base.ResolveReference(route) + + values := url.Values{} + if download { + values.Set(IsDownloadQuery, "true") + } + + base.RawQuery = values.Encode() + return base +} + +func SlaveMediaMetaRoute(src, ext string) string { + src = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(src))) + return fmt.Sprintf("file/meta/%s/%s", src, url.PathEscape(ext)) +} + +func SlaveThumbUrl(base *url.URL, srcPath, ext string) *url.URL { + srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath))) + ext = url.PathEscape(ext) + route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/thumb/%s/%s", srcPath, ext)) + base = base.ResolveReference(route) + return base +} + +func SlaveGetTaskRoute(id int, deleteOnComplete bool) string { + p := constants.APIPrefixSlave + "/task/" + strconv.Itoa(id) + if deleteOnComplete { + p += "?" 
+ SlaveClearTaskRegistryQuery + "=true" + } + return p +} + +func SlavePingRoute(base *url.URL) string { + route, _ := url.Parse(constants.APIPrefixSlave + "/ping") + return base.ResolveReference(route).String() +} diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go deleted file mode 100644 index 94d286bd..00000000 --- a/pkg/cluster/slave.go +++ /dev/null @@ -1,451 +0,0 @@ -package cluster - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io" - "net/url" - "strings" - "sync" - "time" -) - -type SlaveNode struct { - Model *model.Node - Active bool - - caller slaveCaller - callback func(bool, uint) - close chan bool - lock sync.RWMutex -} - -type slaveCaller struct { - parent *SlaveNode - Client request.Client -} - -// Init 初始化节点 -func (node *SlaveNode) Init(nodeModel *model.Node) { - node.lock.Lock() - node.Model = nodeModel - - // Init http request client - var endpoint *url.URL - if serverURL, err := url.Parse(node.Model.Server); err == nil { - var controller *url.URL - controller, _ = url.Parse("/api/v3/slave/") - endpoint = serverURL.ResolveReference(controller) - } - - signTTL := model.GetIntSetting("slave_api_timeout", 60) - node.caller.Client = request.NewClient( - request.WithMasterMeta(), - request.WithTimeout(time.Duration(signTTL)*time.Second), - request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)), - request.WithEndpoint(endpoint.String()), - ) - - node.caller.parent = node - if node.close != nil { - node.lock.Unlock() - node.close <- true - go node.StartPingLoop() - } else { - node.Active = true - node.lock.Unlock() - go node.StartPingLoop() - } -} - -// IsFeatureEnabled 查询节点的某项功能是否启用 -func (node *SlaveNode) IsFeatureEnabled(feature string) bool { - node.lock.RLock() - defer node.lock.RUnlock() - - switch feature { - case "aria2": - return node.Model.Aria2Enabled - default: - return false - } -} - -// SubscribeStatusChange 订阅节点状态更改 -func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) { - node.lock.Lock() - node.callback = callback - node.lock.Unlock() -} - -// Ping 从机节点,返回从机负载 -func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { - node.lock.RLock() - defer node.lock.RUnlock() - - reqBodyEncoded, err := json.Marshal(req) - if err != nil { - return nil, err - } - - bodyReader := strings.NewReader(string(reqBodyEncoded)) - - resp, err := node.caller.Client.Request( - "POST", - "heartbeat", - bodyReader, - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return nil, err - } - - // 处理列取结果 - if resp.Code != 0 { - return nil, serializer.NewErrorFromResponse(resp) - } - - var res serializer.NodePingResp - - if resStr, ok := resp.Data.(string); ok { - err = json.Unmarshal([]byte(resStr), &res) - if err != nil { - return nil, err - } - } - - return &res, nil -} - -// IsActive 返回节点是否在线 -func (node *SlaveNode) IsActive() bool { - node.lock.RLock() - defer node.lock.RUnlock() - - return node.Active -} - -// Kill 结束节点内相关循环 -func (node *SlaveNode) Kill() { - node.lock.RLock() - defer node.lock.RUnlock() - - if node.close != nil { - close(node.close) - } -} - -// 
GetAria2Instance 获取从机Aria2实例 -func (node *SlaveNode) GetAria2Instance() common.Aria2 { - node.lock.RLock() - defer node.lock.RUnlock() - - if !node.Model.Aria2Enabled { - return &common.DummyAria2{} - } - - return &node.caller -} - -func (node *SlaveNode) ID() uint { - node.lock.RLock() - defer node.lock.RUnlock() - - return node.Model.ID -} - -func (node *SlaveNode) StartPingLoop() { - node.lock.Lock() - node.close = make(chan bool) - node.lock.Unlock() - - tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second - recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second - pingTicker := time.Duration(0) - - util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name) - retry := 0 - recoverMode := false - isFirstLoop := true - -loop: - for { - select { - case <-time.After(pingTicker): - if pingTicker == 0 { - pingTicker = tickDuration - } - - util.Log().Debug("Slave node %q send ping.", node.Model.Name) - res, err := node.Ping(node.getHeartbeatContent(isFirstLoop)) - isFirstLoop = false - - if err != nil { - util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err) - retry++ - if retry >= model.GetIntSetting("slave_node_retry", 3) { - util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name) - node.changeStatus(false) - - if !recoverMode { - // 启动恢复监控循环 - util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name) - pingTicker = recoverDuration - recoverMode = true - } - } - } else { - if recoverMode { - util.Log().Debug("Slave node %q recovered.", node.Model.Name) - pingTicker = tickDuration - recoverMode = false - isFirstLoop = true - } - - util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res) - node.changeStatus(true) - retry = 0 - } - - case <-node.close: - util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name) - break loop - } - } -} - -func (node *SlaveNode) IsMater() bool { - return false -} - -func (node *SlaveNode) MasterAuthInstance() auth.Auth { - node.lock.RLock() - defer node.lock.RUnlock() - - return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} -} - -func (node *SlaveNode) SlaveAuthInstance() auth.Auth { - node.lock.RLock() - defer node.lock.RUnlock() - - return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} -} - -func (node *SlaveNode) DBModel() *model.Node { - node.lock.RLock() - defer node.lock.RUnlock() - - return node.Model -} - -// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave -func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { - return &serializer.NodePingReq{ - SiteURL: model.GetSiteURL().String(), - IsUpdate: isUpdate, - SiteID: model.GetSettingByName("siteID"), - Node: node.Model, - CredentialTTL: model.GetIntSetting("slave_api_timeout", 60), - } -} - -func (node *SlaveNode) changeStatus(isActive bool) { - node.lock.RLock() - id := node.Model.ID - if isActive != node.Active { - node.lock.RUnlock() - node.lock.Lock() - node.Active = isActive - node.lock.Unlock() - node.callback(isActive, id) - } else { - node.lock.RUnlock() - } - -} - -func (s *slaveCaller) Init() error { - return nil -} - -// SendAria2Call send remote aria2 call to slave node -func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) { - reqReader, err := getAria2RequestBody(body) - if err != nil { - return nil, err - } - - return s.Client.Request( - 
"POST", - "aria2/"+scope, - reqReader, - ).CheckHTTPResponse(200).DecodeResponse() -} - -func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - req := &serializer.SlaveAria2Call{ - Task: task, - GroupOptions: options, - } - - res, err := s.SendAria2Call(req, "task") - if err != nil { - return "", err - } - - if res.Code != 0 { - return "", serializer.NewErrorFromResponse(res) - } - - return res.Data.(string), err -} - -func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - req := &serializer.SlaveAria2Call{ - Task: task, - } - - res, err := s.SendAria2Call(req, "status") - if err != nil { - return rpc.StatusInfo{}, err - } - - if res.Code != 0 { - return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res) - } - - var status rpc.StatusInfo - res.GobDecode(&status) - - return status, err -} - -func (s *slaveCaller) Cancel(task *model.Download) error { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - req := &serializer.SlaveAria2Call{ - Task: task, - } - - res, err := s.SendAria2Call(req, "cancel") - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - return nil -} - -func (s *slaveCaller) Select(task *model.Download, files []int) error { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - req := &serializer.SlaveAria2Call{ - Task: task, - Files: files, - } - - res, err := s.SendAria2Call(req, "select") - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - return nil -} - -func (s *slaveCaller) GetConfig() model.Aria2Option { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - return s.parent.Model.Aria2OptionsSerialized -} - -func (s *slaveCaller) DeleteTempFile(task *model.Download) error { - s.parent.lock.RLock() - defer s.parent.lock.RUnlock() - - req := &serializer.SlaveAria2Call{ - Task: task, - } - - res, err := s.SendAria2Call(req, "delete") - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - return nil -} - -func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { - reqBodyEncoded, err := json.Marshal(body) - if err != nil { - return nil, err - } - - return strings.NewReader(string(reqBodyEncoded)), nil -} - -// RemoteCallback 发送远程存储策略上传回调请求 -func RemoteCallback(url string, body serializer.UploadCallback) error { - callbackBody, err := json.Marshal(struct { - Data serializer.UploadCallback `json:"data"` - }{ - Data: body, - }) - if err != nil { - return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err) - } - - resp := request.GeneralClient.Request( - "POST", - url, - bytes.NewReader(callbackBody), - request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), - request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), - ) - - if resp.Err != nil { - return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err) - } - - // 解析回调服务端响应 - response, err := resp.DecodeResponse() - if err != nil { - msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode) - return serializer.NewError(serializer.CodeCallbackError, msg, err) - } - - if response.Code != 0 { - return serializer.NewError(response.Code, 
response.Msg, errors.New(response.Error)) - } - - return nil -} diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go deleted file mode 100644 index 1b1510f6..00000000 --- a/pkg/cluster/slave_test.go +++ /dev/null @@ -1,559 +0,0 @@ -package cluster - -import ( - "bytes" - "encoding/json" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" -) - -func TestSlaveNode_InitAndKill(t *testing.T) { - a := assert.New(t) - n := &SlaveNode{ - callback: func(b bool, u uint) { - - }, - } - - a.NotPanics(func() { - n.Init(&model.Node{}) - time.Sleep(time.Millisecond * 500) - n.Init(&model.Node{}) - n.Kill() - }) -} - -func TestSlaveNode_DummyMethods(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - m.Model.ID = 5 - a.Equal(m.Model.ID, m.ID()) - a.Equal(m.Model.ID, m.DBModel().ID) - - a.False(m.IsActive()) - a.False(m.IsMater()) - - m.SubscribeStatusChange(func(isActive bool, id uint) {}) -} - -func TestSlaveNode_IsFeatureEnabled(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.False(m.IsFeatureEnabled("aria2")) - a.False(m.IsFeatureEnabled("random")) - m.Model.Aria2Enabled = true - a.True(m.IsFeatureEnabled("aria2")) -} - -func TestSlaveNode_Ping(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - // master return error code - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.Error(err) - a.Nil(res) - a.Equal(1, err.(serializer.AppError).Code) - } - - // return unexpected json - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.Error(err) - a.Nil(res) - } - - // return success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.NoError(err) - a.NotNil(res) - } -} - -func TestSlaveNode_GetAria2Instance(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.NotNil(m.GetAria2Instance()) - m.Model.Aria2Enabled = true - a.NotNil(m.GetAria2Instance()) - a.NotNil(m.GetAria2Instance()) -} - -func TestSlaveNode_StartPingLoop(t *testing.T) { - callbackCount := 0 - finishedChan := make(chan struct{}) - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m := 
&SlaveNode{ - Active: true, - Model: &model.Node{}, - callback: func(b bool, u uint) { - callbackCount++ - if callbackCount == 2 { - close(finishedChan) - } - if callbackCount == 1 { - mockRequest.AssertExpectations(t) - mockRequest = requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), - }, - }) - } - }, - } - cache.Set("setting_slave_ping_interval", "0", 0) - cache.Set("setting_slave_recover_interval", "0", 0) - cache.Set("setting_slave_node_retry", "1", 0) - - m.caller.Client = &mockRequest - go func() { - select { - case <-finishedChan: - m.Kill() - } - }() - m.StartPingLoop() - mockRequest.AssertExpectations(t) -} - -func TestSlaveNode_AuthInstance(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.NotNil(m.MasterAuthInstance()) - a.NotNil(m.SlaveAuthInstance()) -} - -func TestSlaveNode_ChangeStatus(t *testing.T) { - a := assert.New(t) - isActive := false - m := &SlaveNode{ - Model: &model.Node{}, - callback: func(b bool, u uint) { - isActive = b - }, - } - - a.NotPanics(func() { - m.changeStatus(false) - }) - m.changeStatus(true) - a.True(isActive) -} - -func getTestRPCNodeSlave() *SlaveNode { - m := &SlaveNode{ - Model: &model.Node{}, - } - m.caller.parent = m - return m -} - -func TestSlaveCaller_CreateTask(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Empty(res) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Empty(res) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Equal("res", res) - a.NoError(err) - } -} - -func TestSlaveCaller_Status(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := 
m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.NoError(err) - } -} - -func TestSlaveCaller_Cancel(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.NoError(err) - } -} - -func TestSlaveCaller_Select(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - m.caller.Init() - m.caller.GetConfig() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.NoError(err) - } -} - -func TestSlaveCaller_DeleteTempFile(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - m.caller.Init() - m.caller.GetConfig() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - 
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.NoError(err) - } -} - -func TestRemoteCallback(t *testing.T) { - asserts := assert.New(t) - - // 回调成功 - { - clientMock := requestmock.RequestMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 0}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.NoError(resp) - clientMock.AssertExpectations(t) - } - - // 服务端返回业务错误 - { - clientMock := requestmock.RequestMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 401}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.EqualValues(401, resp.(serializer.AppError).Code) - clientMock.AssertExpectations(t) - } - - // 无法解析回调响应 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // HTTP状态码非200 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // 无法发起回调 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } -} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index b0a4ea4d..78e770b7 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -1,145 +1,146 @@ package conf 
import ( - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/go-ini/ini" "github.com/go-playground/validator/v10" + "os" + "strings" ) -// database 数据库 -type database struct { - Type string - User string - Password string - Host string - Name string - TablePrefix string - DBFile string - Port int - Charset string - UnixSocket bool -} - -// system 系统通用配置 -type system struct { - Mode string `validate:"eq=master|eq=slave"` - Listen string `validate:"required"` - Debug bool - SessionSecret string - HashIDSalt string - GracePeriod int `validate:"gte=0"` - ProxyHeader string `validate:"required_with=Listen"` -} - -type ssl struct { - CertPath string `validate:"omitempty,required"` - KeyPath string `validate:"omitempty,required"` - Listen string `validate:"required"` -} - -type unix struct { - Listen string - Perm uint32 -} - -// slave 作为slave存储端配置 -type slave struct { - Secret string `validate:"omitempty,gte=64"` - CallbackTimeout int `validate:"omitempty,gte=1"` - SignatureTTL int `validate:"omitempty,gte=1"` -} - -// redis 配置 -type redis struct { - Network string - Server string - User string - Password string - DB string -} +const ( + envConfOverrideKey = "CR_CONF_" +) -// 跨域配置 -type cors struct { - AllowOrigins []string - AllowMethods []string - AllowHeaders []string - AllowCredentials bool - ExposeHeaders []string - SameSite string - Secure bool +type ConfigProvider interface { + Database() *Database + System() *System + SSL() *SSL + Unix() *Unix + Slave() *Slave + Redis() *Redis + Cors() *Cors + OptionOverwrite() map[string]any } -var cfg *ini.File - -const defaultConf = `[System] -Debug = false -Mode = master -Listen = :5212 -SessionSecret = {SessionSecret} -HashIDSalt = {HashIDSalt} -` - -// Init 初始化配置文件 -func Init(path string) { - var err error - - if path == "" || !util.Exists(path) { +// NewIniConfigProvider initializes a new Ini config file provider. A default config file +// will be created if the given path does not exist. 
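The provider constructed by `NewIniConfigProvider` (defined next) also accepts overrides from the environment: every variable prefixed with `CR_CONF_` is split on the first `.` into a section and a key, collected into an in-memory ini overlay, and loaded on top of the config file (see `getOverrideConfFromEnv` later in this file's diff). Below is a rough, self-contained sketch of that env-to-ini conversion using only the standard library; the overlay formatting mirrors the code in this diff, and the extra guard for keys without a dot is an addition for illustration.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// Sketch of the CR_CONF_ override: "CR_CONF_Database.Type=sqlite" becomes an
// ini overlay "[Database]\nType = sqlite" merged over the on-disk config file.
func main() {
	_ = os.Setenv("CR_CONF_Database.Type", "sqlite")
	_ = os.Setenv("CR_CONF_System.Debug", "true")

	sections := make(map[string]map[string]string)
	for _, env := range os.Environ() {
		if !strings.HasPrefix(env, "CR_CONF_") {
			continue
		}
		kv := strings.SplitN(env, "=", 2)
		key := strings.TrimPrefix(kv[0], "CR_CONF_")
		parts := strings.SplitN(key, ".", 2)
		if len(parts) != 2 {
			continue // no "Section.Key" form, skip (guard added for this sketch)
		}
		if sections[parts[0]] == nil {
			sections[parts[0]] = make(map[string]string)
		}
		sections[parts[0]][parts[1]] = kv[1]
	}

	var sb strings.Builder
	for section, kvs := range sections {
		sb.WriteString(fmt.Sprintf("[%s]\n", section))
		for k, v := range kvs {
			sb.WriteString(fmt.Sprintf("%s = %s\n", k, v))
		}
	}
	fmt.Print(sb.String())
}
```

Note that POSIX shells cannot `export` names containing a dot, so overrides like `CR_CONF_Database.Type` are typically supplied through a container runtime's env file rather than an interactive shell.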
+func NewIniConfigProvider(configPath string, l logging.Logger) (ConfigProvider, error) { + if configPath == "" || !util.Exists(configPath) { + l.Info("Config file %q not found, creating a new one.", configPath) // 创建初始配置文件 confContent := util.Replace(map[string]string{ "{SessionSecret}": util.RandStringRunes(64), - "{HashIDSalt}": util.RandStringRunes(64), }, defaultConf) - f, err := util.CreatNestedFile(path) + f, err := util.CreatNestedFile(configPath) if err != nil { - util.Log().Panic("Failed to create config file: %s", err) + return nil, fmt.Errorf("failed to create config file: %w", err) } // 写入配置文件 _, err = f.WriteString(confContent) if err != nil { - util.Log().Panic("Failed to write config file: %s", err) + return nil, fmt.Errorf("failed to write config file: %w", err) } f.Close() } - cfg, err = ini.Load(path) + cfg, err := ini.Load(configPath, []byte(getOverrideConfFromEnv(l))) if err != nil { - util.Log().Panic("Failed to parse config file %q: %s", path, err) + return nil, fmt.Errorf("failed to parse config file %q: %w", configPath, err) + } + + provider := &iniConfigProvider{ + database: *DatabaseConfig, + system: *SystemConfig, + ssl: *SSLConfig, + unix: *UnixConfig, + slave: *SlaveConfig, + redis: *RedisConfig, + cors: *CORSConfig, + optionOverwrite: make(map[string]interface{}), } sections := map[string]interface{}{ - "Database": DatabaseConfig, - "System": SystemConfig, - "SSL": SSLConfig, - "UnixSocket": UnixConfig, - "Redis": RedisConfig, - "CORS": CORSConfig, - "Slave": SlaveConfig, + "Database": &provider.database, + "System": &provider.system, + "SSL": &provider.ssl, + "UnixSocket": &provider.unix, + "Redis": &provider.redis, + "CORS": &provider.cors, + "Slave": &provider.slave, } for sectionName, sectionStruct := range sections { - err = mapSection(sectionName, sectionStruct) + err = mapSection(cfg, sectionName, sectionStruct) if err != nil { - util.Log().Panic("Failed to parse config section %q: %s", sectionName, err) + return nil, fmt.Errorf("failed to parse config section %q: %w", sectionName, err) } } // 映射数据库配置覆盖 for _, key := range cfg.Section("OptionOverwrite").Keys() { - OptionOverwrite[key.Name()] = key.Value() + provider.optionOverwrite[key.Name()] = key.Value() } - // 重设log等级 - if !SystemConfig.Debug { - util.Level = util.LevelInformational - util.GloablLogger = nil - util.Log() - } + return provider, nil +} + +type iniConfigProvider struct { + database Database + system System + ssl SSL + unix Unix + slave Slave + redis Redis + cors Cors + optionOverwrite map[string]any +} +func (i *iniConfigProvider) Database() *Database { + return &i.database } +func (i *iniConfigProvider) System() *System { + return &i.system +} + +func (i *iniConfigProvider) SSL() *SSL { + return &i.ssl +} + +func (i *iniConfigProvider) Unix() *Unix { + return &i.unix +} + +func (i *iniConfigProvider) Slave() *Slave { + return &i.slave +} + +func (i *iniConfigProvider) Redis() *Redis { + return &i.redis +} + +func (i *iniConfigProvider) Cors() *Cors { + return &i.cors +} + +func (i *iniConfigProvider) OptionOverwrite() map[string]any { + return i.optionOverwrite +} + +const defaultConf = `[System] +Debug = false +Mode = master +Listen = :5212 +SessionSecret = {SessionSecret} +HashIDSalt = {HashIDSalt} +` + // mapSection 将配置文件的 Section 映射到结构体上 -func mapSection(section string, confStruct interface{}) error { +func mapSection(cfg *ini.File, section string, confStruct interface{}) error { err := cfg.Section(section).MapTo(confStruct) if err != nil { return err @@ -154,3 +155,35 @@ func 
mapSection(section string, confStruct interface{}) error { return nil } + +func getOverrideConfFromEnv(l logging.Logger) string { + confMaps := make(map[string]map[string]string) + for _, env := range os.Environ() { + if !strings.HasPrefix(env, envConfOverrideKey) { + continue + } + + // split by key=value and get key + kv := strings.SplitN(env, "=", 2) + configKey := strings.TrimPrefix(kv[0], envConfOverrideKey) + configValue := kv[1] + sectionKey := strings.SplitN(configKey, ".", 2) + if confMaps[sectionKey[0]] == nil { + confMaps[sectionKey[0]] = make(map[string]string) + } + + confMaps[sectionKey[0]][sectionKey[1]] = configValue + l.Info("Override config %q = %q", configKey, configValue) + } + + // generate ini content + var sb strings.Builder + for section, kvs := range confMaps { + sb.WriteString(fmt.Sprintf("[%s]\n", section)) + for k, v := range kvs { + sb.WriteString(fmt.Sprintf("%s = %s\n", k, v)) + } + } + + return sb.String() +} diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index 6d186ed4..e9bc646d 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -1,12 +1,11 @@ package conf import ( + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/stretchr/testify/assert" "io/ioutil" "os" "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" ) // 测试Init日志路径错误 @@ -15,10 +14,10 @@ func TestInitPanic(t *testing.T) { // 日志路径不存在时 asserts.NotPanics(func() { - Init("not/exist/path/conf.ini") + Init("not/exist/path") }) - asserts.True(util.Exists("not/exist/path/conf.ini")) + asserts.True(util.Exists("conf.ini")) } @@ -56,11 +55,7 @@ User = root Password = root Host = 127.0.0.1:3306 Name = v3 -TablePrefix = v3_ - -[OptionOverwrite] -key=value -` +TablePrefix = v3_` err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) defer func() { err = os.Remove("testConf.ini") }() if err != nil { @@ -69,7 +64,6 @@ key=value asserts.NotPanics(func() { Init("testConf.ini") }) - asserts.Equal(OptionOverwrite["key"], "value") } func TestMapSection(t *testing.T) { diff --git a/pkg/conf/types.go b/pkg/conf/types.go new file mode 100644 index 00000000..32e2e4f3 --- /dev/null +++ b/pkg/conf/types.go @@ -0,0 +1,138 @@ +package conf + +import "github.com/cloudreve/Cloudreve/v4/pkg/util" + +type DBType string + +var ( + SQLiteDB DBType = "sqlite" + SQLite3DB DBType = "sqlite3" + MySqlDB DBType = "mysql" + MsSqlDB DBType = "mssql" + PostgresDB DBType = "postgres" +) + +// Database 数据库 +type Database struct { + Type DBType + User string + Password string + Host string + Name string + TablePrefix string + DBFile string + Port int + Charset string + UnixSocket bool +} + +type SysMode string + +var ( + MasterMode SysMode = "master" + SlaveMode SysMode = "slave" +) + +// System 系统通用配置 +type System struct { + Mode SysMode `validate:"eq=master|eq=slave"` + Listen string `validate:"required"` + Debug bool + SessionSecret string + HashIDSalt string // deprecated + GracePeriod int `validate:"gte=0"` + ProxyHeader string `validate:"required_with=Listen"` + LogLevel string `validate:"oneof=debug info warning error"` +} + +type SSL struct { + CertPath string `validate:"omitempty,required"` + KeyPath string `validate:"omitempty,required"` + Listen string `validate:"required"` +} + +type Unix struct { + Listen string + Perm uint32 +} + +// Slave 作为slave存储端配置 +type Slave struct { + Secret string `validate:"omitempty,gte=64"` + CallbackTimeout int `validate:"omitempty,gte=1"` + SignatureTTL int `validate:"omitempty,gte=1"` +} + +// Redis 配置 +type 
Redis struct { + Network string + Server string + User string + Password string + DB string +} + +// 跨域配置 +type Cors struct { + AllowOrigins []string + AllowMethods []string + AllowHeaders []string + AllowCredentials bool + ExposeHeaders []string + SameSite string + Secure bool +} + +// RedisConfig Redis服务器配置 +var RedisConfig = &Redis{ + Network: "tcp", + Server: "", + Password: "", + DB: "0", +} + +// DatabaseConfig 数据库配置 +var DatabaseConfig = &Database{ + Charset: "utf8mb4", + DBFile: util.DataPath("cloudreve.db"), + Port: 3306, + UnixSocket: false, +} + +// SystemConfig 系统公用配置 +var SystemConfig = &System{ + Debug: false, + Mode: MasterMode, + Listen: ":5212", + ProxyHeader: "X-Forwarded-For", + LogLevel: "info", +} + +// CORSConfig 跨域配置 +var CORSConfig = &Cors{ + AllowOrigins: []string{"UNSET"}, + AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"}, + AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"}, + AllowCredentials: false, + ExposeHeaders: nil, + SameSite: "Default", + Secure: false, +} + +// SlaveConfig 从机配置 +var SlaveConfig = &Slave{ + CallbackTimeout: 20, + SignatureTTL: 600, +} + +var SSLConfig = &SSL{ + Listen: ":443", + CertPath: "", + KeyPath: "", +} + +var UnixConfig = &Unix{ + Listen: "", +} + +var OptionOverwrite = map[string]interface{}{} diff --git a/pkg/conf/version.go b/pkg/conf/version.go deleted file mode 100644 index 6720e8ca..00000000 --- a/pkg/conf/version.go +++ /dev/null @@ -1,16 +0,0 @@ -package conf - -// BackendVersion 当前后端版本号 -var BackendVersion = "3.8.3" - -// RequiredDBVersion 与当前版本匹配的数据库版本 -var RequiredDBVersion = "3.8.1" - -// RequiredStaticVersion 与当前版本匹配的静态资源版本 -var RequiredStaticVersion = "3.8.3" - -// IsPro 是否为Pro版本 -var IsPro = "false" - -// LastCommit 最后commit id -var LastCommit = "a11f819" diff --git a/pkg/credmanager/credmanager.go b/pkg/credmanager/credmanager.go new file mode 100644 index 00000000..b47deb0c --- /dev/null +++ b/pkg/credmanager/credmanager.go @@ -0,0 +1,238 @@ +package credmanager + +import ( + "context" + "encoding/gob" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" +) + +type ( + // CredManager is a centralized for all Oauth tokens that requires periodic refresh + // It is primarily used by OneDrive storage policy. 
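The `CredManager` interface itself follows; each storage policy supplies its own `Credential` implementation, and `Obtain` returns the cached value while it is still valid, then calls `Refresh` under a per-key lock once `Expiry()` has passed. A minimal illustration of such an implementation is sketched here; the `staticCredential` type and its behaviour are hypothetical and only need to satisfy the `Credential` interface declared just below.

```go
package credmanager

import (
	"context"
	"time"
)

// staticCredential is a hypothetical Credential used for illustration only:
// it "refreshes" by extending its own expiry by a fixed TTL.
type staticCredential struct {
	key         string
	token       string
	expireAt    time.Time
	refreshedAt *time.Time
}

func (c *staticCredential) String() string          { return c.token }
func (c *staticCredential) Key() string             { return c.key }
func (c *staticCredential) Expiry() time.Time       { return c.expireAt }
func (c *staticCredential) RefreshedAt() *time.Time { return c.refreshedAt }

func (c *staticCredential) Refresh(ctx context.Context) (Credential, error) {
	// A real implementation would call the upstream OAuth token endpoint here.
	now := time.Now()
	return &staticCredential{
		key:         c.key,
		token:       c.token,
		expireAt:    now.Add(time.Hour),
		refreshedAt: &now,
	}, nil
}
```

A policy would call `Upsert` once with such a value and later read it back with `Obtain`, which transparently swaps in the refreshed credential when the cached one has expired.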
+ CredManager interface { + // Obtain gets a credential from the manager, refresh it if it's expired + Obtain(ctx context.Context, key string) (Credential, error) + // Upsert inserts or updates a credential in the manager + Upsert(ctx context.Context, cred ...Credential) error + RefreshAll(ctx context.Context) + } + + Credential interface { + String() string + Refresh(ctx context.Context) (Credential, error) + Key() string + Expiry() time.Time + RefreshedAt() *time.Time + } +) + +func init() { + gob.Register(CredentialResponse{}) +} + +func New(kv cache.Driver) CredManager { + return &credManager{ + kv: kv, + locks: make(map[string]*sync.Mutex), + } +} + +type ( + credManager struct { + kv cache.Driver + mu sync.RWMutex + + locks map[string]*sync.Mutex + } +) + +var ( + ErrNotFound = errors.New("credential not found") +) + +func (m *credManager) Upsert(ctx context.Context, cred ...Credential) error { + m.mu.Lock() + defer m.mu.Unlock() + + l := logging.FromContext(ctx) + for _, c := range cred { + l.Info("CredManager: Upsert credential for key %q...", c.Key()) + if err := m.kv.Set(c.Key(), c, 0); err != nil { + return fmt.Errorf("failed to update credential in KV for key %q: %w", c.Key(), err) + } + + if _, ok := m.locks[c.Key()]; !ok { + m.locks[c.Key()] = &sync.Mutex{} + } + } + + return nil +} + +func (m *credManager) Obtain(ctx context.Context, key string) (Credential, error) { + m.mu.RLock() + itemRaw, ok := m.kv.Get(key) + if !ok { + m.mu.RUnlock() + return nil, fmt.Errorf("credential not found for key %q: %w", key, ErrNotFound) + } + + l := logging.FromContext(ctx) + + item := itemRaw.(Credential) + if _, ok := m.locks[key]; !ok { + m.locks[key] = &sync.Mutex{} + } + m.locks[key].Lock() + defer m.locks[key].Unlock() + m.mu.RUnlock() + + if item.Expiry().After(time.Now()) { + // Credential is still valid + return item, nil + } + + // Credential is expired, refresh it + l.Info("Refreshing credential for key %q...", key) + newCred, err := item.Refresh(ctx) + if err != nil { + return nil, fmt.Errorf("failed to refresh credential for key %q: %w", key, err) + } + + l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String()) + if err := m.kv.Set(key, newCred, 0); err != nil { + return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err) + } + + return newCred, nil +} + +func (m *credManager) RefreshAll(ctx context.Context) { + m.mu.RLock() + defer m.mu.RUnlock() + + l := logging.FromContext(ctx) + for key := range m.locks { + l.Info("Refreshing credential for key %q...", key) + m.locks[key].Lock() + defer m.locks[key].Unlock() + + itemRaw, ok := m.kv.Get(key) + if !ok { + l.Warning("Credential not found for key %q", key) + continue + } + + item := itemRaw.(Credential) + newCred, err := item.Refresh(ctx) + if err != nil { + l.Warning("Failed to refresh credential for key %q: %s", key, err) + continue + } + + l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String()) + if err := m.kv.Set(key, newCred, 0); err != nil { + l.Warning("Failed to update credential in KV for key %q: %s", key, err) + } + } +} + +type ( + slaveCredManager struct { + kv cache.Driver + client request.Client + } + + CredentialResponse struct { + Token string `json:"token"` + ExpireAt time.Time `json:"expire_at"` + } +) + +func NewSlaveManager(kv cache.Driver, config conf.ConfigProvider) CredManager { + return &slaveCredManager{ + kv: kv, + client: request.NewClient( + config, + request.WithCredential(auth.HMACAuth{ + 
[]byte(config.Slave().Secret), + }, int64(config.Slave().SignatureTTL)), + ), + } +} + +func (c CredentialResponse) String() string { + return c.Token +} + +func (c CredentialResponse) Refresh(ctx context.Context) (Credential, error) { + return c, nil +} + +func (c CredentialResponse) Key() string { + return "" +} + +func (c CredentialResponse) Expiry() time.Time { + return c.ExpireAt +} + +func (c CredentialResponse) RefreshedAt() *time.Time { + return nil +} + +func (m *slaveCredManager) Upsert(ctx context.Context, cred ...Credential) error { + return nil +} + +func (m *slaveCredManager) Obtain(ctx context.Context, key string) (Credential, error) { + itemRaw, ok := m.kv.Get(key) + if !ok { + return m.requestCredFromMaster(ctx, key) + } + + return itemRaw.(Credential), nil +} + +// No op on slave node +func (m *slaveCredManager) RefreshAll(ctx context.Context) {} + +func (m *slaveCredManager) requestCredFromMaster(ctx context.Context, key string) (Credential, error) { + l := logging.FromContext(ctx) + l.Info("SlaveCredManager: Requesting credential for key %q from master...", key) + + requestDst := routes.MasterGetCredentialUrl(cluster.MasterSiteUrlFromContext(ctx), key) + resp, err := m.client.Request( + http.MethodGet, + requestDst.String(), + nil, + request.WithContext(ctx), + request.WithLogger(l), + request.WithSlaveMeta(cluster.NodeIdFromContext(ctx)), + request.WithCorrelationID(), + ).CheckHTTPResponse(http.StatusOK).DecodeResponse() + if err != nil { + return nil, fmt.Errorf("failed to request credential from master: %w", err) + } + + cred := &CredentialResponse{} + resp.GobDecode(&cred) + + if err := m.kv.Set(key, *cred, max(int(time.Until(cred.Expiry()).Seconds()), 1)); err != nil { + return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err) + } + + return cred, nil +} diff --git a/pkg/crontab/collect.go b/pkg/crontab/collect.go deleted file mode 100644 index a5678f62..00000000 --- a/pkg/crontab/collect.go +++ /dev/null @@ -1,99 +0,0 @@ -package crontab - -import ( - "context" - "os" - "path/filepath" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -func garbageCollect() { - // 清理打包下载产生的临时文件 - collectArchiveFile() - - // 清理过期的内置内存缓存 - if store, ok := cache.Store.(*cache.MemoStore); ok { - collectCache(store) - } - - util.Log().Info("Crontab job \"cron_garbage_collect\" complete.") -} - -func collectArchiveFile() { - // 读取有效期、目录设置 - tempPath := util.RelativePath(model.GetSettingByName("temp_path")) - expires := model.GetIntSetting("download_timeout", 30) - - // 列出文件 - root := filepath.Join(tempPath, "archive") - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - if err == nil && !info.IsDir() && - strings.HasPrefix(filepath.Base(path), "archive_") && - time.Now().Sub(info.ModTime()).Seconds() > float64(expires) { - util.Log().Debug("Delete expired batch download temp file %q.", path) - // 删除符合条件的文件 - if err := os.Remove(path); err != nil { - util.Log().Debug("Failed to delete temp file %q: %s", path, err) - } - } - return nil - }) - - if err != nil { - util.Log().Debug("Crontab job cannot list temp batch download folder: %s", err) - } - -} - -func collectCache(store *cache.MemoStore) { - util.Log().Debug("Cleanup memory cache.") - store.GarbageCollect() -} - -func uploadSessionCollect() { - placeholders := model.GetUploadPlaceholderFiles(0) - - 
// 将过期的上传会话按照用户分组 - userToFiles := make(map[uint][]uint) - for _, file := range placeholders { - _, sessionExist := cache.Get(filesystem.UploadSessionCachePrefix + *file.UploadSessionID) - if sessionExist { - continue - } - - if _, ok := userToFiles[file.UserID]; !ok { - userToFiles[file.UserID] = make([]uint, 0) - } - - userToFiles[file.UserID] = append(userToFiles[file.UserID], file.ID) - } - - // 删除过期的会话 - for uid, filesIDs := range userToFiles { - user, err := model.GetUserByID(uid) - if err != nil { - util.Log().Warning("Owner of the upload session cannot be found: %s", err) - continue - } - - fs, err := filesystem.NewFileSystem(&user) - if err != nil { - util.Log().Warning("Failed to initialize filesystem: %s", err) - continue - } - - if err = fs.Delete(context.Background(), []uint{}, filesIDs, false, false); err != nil { - util.Log().Warning("Failed to delete upload session: %s", err) - } - - fs.Recycle() - } - - util.Log().Info("Crontab job \"cron_recycle_upload_session\" complete.") -} diff --git a/pkg/crontab/crontab.go b/pkg/crontab/crontab.go new file mode 100644 index 00000000..f7febc48 --- /dev/null +++ b/pkg/crontab/crontab.go @@ -0,0 +1,73 @@ +package crontab + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gofrs/uuid" + "github.com/robfig/cron/v3" +) + +type ( + CronTaskFunc func(ctx context.Context) + cornRegistration struct { + t setting.CronType + config string + fn CronTaskFunc + } +) + +var ( + registrations []cornRegistration +) + +// Register registers a cron task. +func Register(t setting.CronType, fn CronTaskFunc) { + registrations = append(registrations, cornRegistration{ + t: t, + fn: fn, + }) +} + +// NewCron constructs a new cron instance with given dependency. 
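Before the `NewCron` definition below, note the registration pattern the v4 rewrite switches to: instead of the hard-coded switch in the deleted `init.go`, packages announce their jobs through `Register`, and `NewCron` looks up each job's schedule from the settings provider before adding it to a `robfig/cron` instance. A hedged sketch of how a package might register a job is shown here; the package name, the `cronTypeTrashBinCollect` constant (and the assumption that `setting.CronType` has a string underlying type), and the job body are all invented for illustration.

```go
package trashbin

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/pkg/crontab"
	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
	"github.com/cloudreve/Cloudreve/v4/pkg/setting"
)

// cronTypeTrashBinCollect is a placeholder setting.CronType value; the real
// constants live in the setting package and are not shown in this diff.
const cronTypeTrashBinCollect setting.CronType = "cron_trash_bin_collect"

func init() {
	// Register runs at package load time; NewCron later resolves the schedule
	// string for this CronType and adds the wrapped task to the cron instance.
	crontab.Register(cronTypeTrashBinCollect, func(ctx context.Context) {
		l := logging.FromContext(ctx)
		l.Info("Collecting expired trash bin entries...")
		// ... actual cleanup work would go here ...
	})
}
```

`taskWrapper` then gives every execution its own correlation ID, a prefixed logger, and the anonymous user in context, so the task body only needs the `ctx` it receives.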
+func NewCron(ctx context.Context, dep dependency.Dep) (*cron.Cron, error) { + settings := dep.SettingProvider() + userClient := dep.UserClient() + anonymous, err := userClient.AnonymousUser(ctx) + if err != nil { + return nil, fmt.Errorf("cron: faield to get anonymous user: %w", err) + } + + l := dep.Logger() + l.Info("Initialize crontab jobs...") + c := cron.New() + + for _, r := range registrations { + cronConfig := settings.Cron(ctx, r.t) + if _, err := c.AddFunc(cronConfig, taskWrapper(string(r.t), cronConfig, anonymous, dep, r.fn)); err != nil { + l.Warning("Failed to start crontab job %q: %s", cronConfig, err) + } + } + + return c, nil +} + +func taskWrapper(name, config string, user *ent.User, dep dependency.Dep, task CronTaskFunc) func() { + l := dep.Logger() + l.Info("Cron task %s started with config %q", name, config) + return func() { + cid := uuid.Must(uuid.NewV4()) + l.Info("Executing Cron task %q with Cid %q", name, cid) + ctx := context.Background() + l := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s Cron: %s]", cid, name)) + ctx = dep.ForkWithLogger(ctx, l) + ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, cid) + ctx = context.WithValue(ctx, logging.LoggerCtx{}, l) + ctx = context.WithValue(ctx, inventory.UserCtx{}, user) + task(ctx) + } +} diff --git a/pkg/crontab/init.go b/pkg/crontab/init.go deleted file mode 100644 index 5971c2c8..00000000 --- a/pkg/crontab/init.go +++ /dev/null @@ -1,47 +0,0 @@ -package crontab - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/robfig/cron/v3" -) - -// Cron 定时任务 -var Cron *cron.Cron - -// Reload 重新启动定时任务 -func Reload() { - if Cron != nil { - Cron.Stop() - } - Init() -} - -// Init 初始化定时任务 -func Init() { - util.Log().Info("Initialize crontab jobs...") - // 读取cron日程设置 - options := model.GetSettingByNames( - "cron_garbage_collect", - "cron_recycle_upload_session", - ) - Cron := cron.New() - for k, v := range options { - var handler func() - switch k { - case "cron_garbage_collect": - handler = garbageCollect - case "cron_recycle_upload_session": - handler = uploadSessionCollect - default: - util.Log().Warning("Unknown crontab job type %q, skipping...", k) - continue - } - - if _, err := Cron.AddFunc(v, handler); err != nil { - util.Log().Warning("Failed to start crontab job %q: %s", k, err) - } - - } - Cron.Start() -} diff --git a/pkg/downloader/aria2/aria2.go b/pkg/downloader/aria2/aria2.go new file mode 100644 index 00000000..81d0d10f --- /dev/null +++ b/pkg/downloader/aria2/aria2.go @@ -0,0 +1,283 @@ +package aria2 + +import ( + "context" + "fmt" + "net/url" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2/rpc" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +const ( + Aria2TempFolder = "aria2" + deleteTempFileDuration = 120 * time.Second +) + +type aria2Client struct { + l logging.Logger + settings setting.Provider + + options *types.Aria2Setting + timeout time.Duration + caller rpc.Client +} + +func New(l logging.Logger, settings setting.Provider, options *types.Aria2Setting) downloader.Downloader { + rpcServer := options.Server + rpcUrl, err := url.Parse(options.Server) + if err == nil { + // add /jsonrpc to the 
url if not present + rpcUrl.Path = "/jsonrpc" + rpcServer = rpcUrl.String() + } + + options.Server = rpcServer + return &aria2Client{ + l: l, + settings: settings, + options: options, + timeout: time.Duration(10) * time.Second, + } +} + +func (a *aria2Client) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) { + caller := a.caller + if caller == nil { + var err error + caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil) + if err != nil { + return nil, fmt.Errorf("cannot create rpc client: %w", err) + } + } + + path := a.tempPath(ctx) + a.l.Info("Creating aria2 task with url %q saving to %q...", url, path) + + // Create the download task options + downloadOptions := map[string]interface{}{} + for k, v := range a.options.Options { + downloadOptions[k] = v + } + for k, v := range options { + downloadOptions[k] = v + } + downloadOptions["dir"] = path + downloadOptions["follow-torrent"] = "mem" + + gid, err := caller.AddURI(url, downloadOptions) + if err != nil || gid == "" { + return nil, err + } + + return &downloader.TaskHandle{ + ID: gid, + }, nil +} + +func (a *aria2Client) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) { + caller := a.caller + if caller == nil { + var err error + caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil) + if err != nil { + return nil, fmt.Errorf("cannot create rpc client: %w", err) + } + } + + status, err := caller.TellStatus(handle.ID) + if err != nil { + return nil, fmt.Errorf("aria2 rpc error: %w", err) + } + + state := downloader.StatusDownloading + switch status.Status { + case "active": + if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength { + state = downloader.StatusSeeding + } else { + state = downloader.StatusDownloading + } + case "waiting", "paused": + state = downloader.StatusDownloading + case "complete": + state = downloader.StatusCompleted + case "error": + state = downloader.StatusError + case "cancelled", "removed": + a.l.Debug("Task %q is cancelled", handle.ID) + return nil, fmt.Errorf("Task canceled: %w", downloader.ErrTaskNotFount) + } + + totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64) + downloaded, _ := strconv.ParseInt(status.CompletedLength, 10, 64) + downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64) + uploaded, _ := strconv.ParseInt(status.UploadLength, 10, 64) + uploadSpeed, _ := strconv.ParseInt(status.UploadSpeed, 10, 64) + numPieces, _ := strconv.Atoi(status.NumPieces) + savePath := filepath.ToSlash(status.Dir) + + res := &downloader.TaskStatus{ + State: state, + Name: status.BitTorrent.Info.Name, + Total: totalLength, + Downloaded: downloaded, + DownloadSpeed: downloadSpeed, + Uploaded: uploaded, + UploadSpeed: uploadSpeed, + SavePath: savePath, + NumPieces: numPieces, + Hash: status.InfoHash, + Files: lo.Map(status.Files, func(item rpc.FileInfo, index int) downloader.TaskFile { + index, _ = strconv.Atoi(item.Index) + size, _ := strconv.ParseInt(item.Length, 10, 64) + completed, _ := strconv.ParseInt(item.CompletedLength, 10, 64) + relPath := strings.TrimPrefix(filepath.ToSlash(item.Path), savePath) + // Remove first letter if any + if len(relPath) > 0 { + relPath = relPath[1:] + } + progress := 0.0 + if size > 0 { + progress = float64(completed) / float64(size) + } + return downloader.TaskFile{ + Index: index, + Name: relPath, + Size: size, + Progress: progress, + Selected: item.Selected == "true", + } + }), + } + + if 
len(status.FollowedBy) > 0 { + res.FollowedBy = &downloader.TaskHandle{ + ID: status.FollowedBy[0], + } + } + + if len(status.Files) == 1 && res.Name == "" { + res.Name = path.Base(filepath.ToSlash(status.Files[0].Path)) + } + + if status.BitField != "" { + res.Pieces = make([]byte, len(status.BitField)/2) + // Convert hex string to bytes + for i := 0; i < len(status.BitField); i += 2 { + b, _ := strconv.ParseInt(status.BitField[i:i+2], 16, 8) + res.Pieces[i/2] = byte(b) + } + } + + return res, nil +} + +func (a *aria2Client) Cancel(ctx context.Context, handle *downloader.TaskHandle) error { + caller := a.caller + if caller == nil { + var err error + caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil) + if err != nil { + return fmt.Errorf("cannot create rpc client: %w", err) + } + } + + status, err := a.Info(ctx, handle) + if err != nil { + return fmt.Errorf("cannot get task: %w", err) + } + + // Delay to delete temp download folder to avoid being locked by aria2 + defer func() { + go func(parent string, l logging.Logger) { + time.Sleep(deleteTempFileDuration) + err := os.RemoveAll(parent) + if err != nil { + l.Warning("Failed to delete temp download folder: %q: %s", parent, err) + } + }(status.SavePath, a.l) + }() + + if _, err := caller.Remove(handle.ID); err != nil { + return fmt.Errorf("aria2 rpc error: %w", err) + } + + return nil +} + +func (a *aria2Client) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error { + caller := a.caller + if caller == nil { + var err error + caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil) + if err != nil { + return fmt.Errorf("cannot create rpc client: %w", err) + } + } + + status, err := a.Info(ctx, handle) + if err != nil { + return fmt.Errorf("cannot get task: %w", err) + } + + selected := lo.SliceToMap(status.Files, func(item downloader.TaskFile) (int, bool) { + return item.Index, true + }) + for _, arg := range args { + if !arg.Download { + delete(selected, arg.Index) + } + } + + _, err = caller.ChangeOption(handle.ID, map[string]interface{}{"select-file": strings.Join(lo.MapToSlice(selected, func(key int, value bool) string { + return strconv.Itoa(key) + }), ",")}) + return err +} + +func (a *aria2Client) Test(ctx context.Context) (string, error) { + caller := a.caller + if caller == nil { + var err error + caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil) + if err != nil { + return "", fmt.Errorf("cannot create rpc client: %w", err) + } + } + + version, err := caller.GetVersion() + if err != nil { + return "", fmt.Errorf("cannot call aria2: %w", err) + } + + return version.Version, nil +} + +func (a *aria2Client) tempPath(ctx context.Context) string { + guid, _ := uuid.NewV4() + + // Generate a unique path for the task + base := util.RelativePath(a.options.TempPath) + if a.options.TempPath == "" { + base = util.DataPath(a.settings.TempPath(ctx)) + } + path := filepath.Join( + base, + Aria2TempFolder, + guid.String(), + ) + return path +} diff --git a/pkg/aria2/rpc/README.md b/pkg/downloader/aria2/rpc/README.md similarity index 100% rename from pkg/aria2/rpc/README.md rename to pkg/downloader/aria2/rpc/README.md diff --git a/pkg/aria2/rpc/call.go b/pkg/downloader/aria2/rpc/call.go similarity index 100% rename from pkg/aria2/rpc/call.go rename to pkg/downloader/aria2/rpc/call.go diff --git a/pkg/aria2/rpc/client.go b/pkg/downloader/aria2/rpc/client.go similarity index 95% rename from 
pkg/aria2/rpc/client.go rename to pkg/downloader/aria2/rpc/client.go index adb9e397..bcce827e 100644 --- a/pkg/aria2/rpc/client.go +++ b/pkg/downloader/aria2/rpc/client.go @@ -268,8 +268,9 @@ func (c *client) TellStatus(gid string, keys ...string) (info StatusInfo, err er // `aria2.getUris([secret, ]gid)` // This method returns the URIs used in the download denoted by gid (string). // The response is an array of structs and it contains following keys. Values are string. -// uri URI -// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue. +// +// uri URI +// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue. func (c *client) GetURIs(gid string) (infos []URIInfo, err error) { params := make([]interface{}, 0, 2) if c.token != "" { @@ -456,12 +457,14 @@ func (c *client) GetOption(gid string) (m Option, err error) { // `aria2.changeOption([secret, ]gid, options)` // This method changes options of the download denoted by gid (string) dynamically. options is a struct. // The following options are available for active downloads: -// bt-max-peers -// bt-request-peer-speed-limit -// bt-remove-unselected-file -// force-save -// max-download-limit -// max-upload-limit +// +// bt-max-peers +// bt-request-peer-speed-limit +// bt-remove-unselected-file +// force-save +// max-download-limit +// max-upload-limit +// // For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option. // This method returns OK for success. func (c *client) ChangeOption(gid string, option Option) (ok string, err error) { @@ -496,17 +499,19 @@ func (c *client) GetGlobalOption() (m Option, err error) { // This method changes global options dynamically. // options is a struct. // The following options are available: -// bt-max-open-files -// download-result -// log -// log-level -// max-concurrent-downloads -// max-download-result -// max-overall-download-limit -// max-overall-upload-limit -// save-cookies -// save-session -// server-stat-of +// +// bt-max-open-files +// download-result +// log +// log-level +// max-concurrent-downloads +// max-download-result +// max-overall-download-limit +// max-overall-upload-limit +// save-cookies +// save-session +// server-stat-of +// // In addition, options listed in the Input File subsection are available, except for following options: checksum, index-out, out, pause and select-file. // With the log option, you can dynamically start logging or change log file. // To stop logging, specify an empty string("") as the parameter value. @@ -525,13 +530,14 @@ func (c *client) ChangeGlobalOption(options Option) (ok string, err error) { // `aria2.getGlobalStat([secret])` // This method returns global statistics such as the overall download and upload speeds. // The response is a struct and contains the following keys. Values are strings. -// downloadSpeed Overall download speed (byte/sec). -// uploadSpeed Overall upload speed(byte/sec). -// numActive The number of active downloads. -// numWaiting The number of waiting downloads. -// numStopped The number of stopped downloads in the current session. -// This value is capped by the --max-download-result option. -// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option. 
+// +// downloadSpeed Overall download speed (byte/sec). +// uploadSpeed Overall upload speed(byte/sec). +// numActive The number of active downloads. +// numWaiting The number of waiting downloads. +// numStopped The number of stopped downloads in the current session. +// This value is capped by the --max-download-result option. +// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option. func (c *client) GetGlobalStat() (info GlobalStatInfo, err error) { params := []string{} if c.token != "" { @@ -569,8 +575,9 @@ func (c *client) RemoveDownloadResult(gid string) (ok string, err error) { // `aria2.getVersion([secret])` // This method returns the version of aria2 and the list of enabled features. // The response is a struct and contains following keys. -// version Version number of aria2 as a string. -// enabledFeatures List of enabled features. Each feature is given as a string. +// +// version Version number of aria2 as a string. +// enabledFeatures List of enabled features. Each feature is given as a string. func (c *client) GetVersion() (info VersionInfo, err error) { params := []string{} if c.token != "" { @@ -583,7 +590,8 @@ func (c *client) GetVersion() (info VersionInfo, err error) { // `aria2.getSessionInfo([secret])` // This method returns session information. // The response is a struct and contains following key. -// sessionId Session ID, which is generated each time when aria2 is invoked. +// +// sessionId Session ID, which is generated each time when aria2 is invoked. func (c *client) GetSessionInfo() (info SessionInfo, err error) { params := []string{} if c.token != "" { diff --git a/pkg/aria2/rpc/const.go b/pkg/downloader/aria2/rpc/const.go similarity index 100% rename from pkg/aria2/rpc/const.go rename to pkg/downloader/aria2/rpc/const.go diff --git a/pkg/aria2/rpc/json2.go b/pkg/downloader/aria2/rpc/json2.go similarity index 100% rename from pkg/aria2/rpc/json2.go rename to pkg/downloader/aria2/rpc/json2.go diff --git a/pkg/aria2/rpc/notification.go b/pkg/downloader/aria2/rpc/notification.go similarity index 100% rename from pkg/aria2/rpc/notification.go rename to pkg/downloader/aria2/rpc/notification.go diff --git a/pkg/aria2/rpc/proc.go b/pkg/downloader/aria2/rpc/proc.go similarity index 100% rename from pkg/aria2/rpc/proc.go rename to pkg/downloader/aria2/rpc/proc.go diff --git a/pkg/aria2/rpc/proto.go b/pkg/downloader/aria2/rpc/proto.go similarity index 100% rename from pkg/aria2/rpc/proto.go rename to pkg/downloader/aria2/rpc/proto.go diff --git a/pkg/aria2/rpc/resp.go b/pkg/downloader/aria2/rpc/resp.go similarity index 100% rename from pkg/aria2/rpc/resp.go rename to pkg/downloader/aria2/rpc/resp.go diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go new file mode 100644 index 00000000..17a44fc6 --- /dev/null +++ b/pkg/downloader/downloader.go @@ -0,0 +1,76 @@ +package downloader + +import ( + "context" + "encoding/gob" + "fmt" +) + +var ( + ErrTaskNotFount = fmt.Errorf("task not found") +) + +type ( + Downloader interface { + // Create a task with the given URL and options overwriting the default settings, returns a task handle for future operations. + CreateTask(ctx context.Context, url string, options map[string]interface{}) (*TaskHandle, error) + // Info returns the status of the task with the given handle. + Info(ctx context.Context, handle *TaskHandle) (*TaskStatus, error) + // Cancel the task with the given handle. 
+ Cancel(ctx context.Context, handle *TaskHandle) error + // SetFilesToDownload sets the files to download for the task with the given handle. + SetFilesToDownload(ctx context.Context, handle *TaskHandle, args ...*SetFileToDownloadArgs) error + // Test tests the connection to the downloader. + Test(ctx context.Context) (string, error) + } + + // TaskHandle represents a task handle for future operations + TaskHandle struct { + ID string `json:"id"` + Hash string `json:"hash"` + } + Status string + TaskStatus struct { + FollowedBy *TaskHandle `json:"-"` // Indicate if the task handle is changed + SavePath string `json:"save_path,omitempty"` + Name string `json:"name"` + State Status `json:"state"` + Total int64 `json:"total"` + Downloaded int64 `json:"downloaded"` + DownloadSpeed int64 `json:"download_speed"` + Uploaded int64 `json:"uploaded"` + UploadSpeed int64 `json:"upload_speed"` + Hash string `json:"hash,omitempty"` + Files []TaskFile `json:"files,omitempty"` + Pieces []byte `json:"pieces,omitempty"` // Hexadecimal representation of the download progress of the peer. The highest bit corresponds to the piece at index 0. + NumPieces int `json:"num_pieces,omitempty"` + } + + TaskFile struct { + Index int `json:"index"` + Name string `json:"name"` + Size int64 `json:"size"` + Progress float64 `json:"progress"` + Selected bool `json:"selected"` + } + + SetFileToDownloadArgs struct { + Index int `json:"index"` + Download bool `json:"download"` + } +) + +const ( + StatusDownloading Status = "downloading" + StatusSeeding Status = "seeding" + StatusCompleted Status = "completed" + StatusError Status = "error" + StatusUnknown Status = "unknown" + + DownloaderCtxKey = "downloader" +) + +func init() { + gob.Register(TaskHandle{}) + gob.Register(TaskStatus{}) +} diff --git a/pkg/downloader/qbittorrent/qbittorrent.go b/pkg/downloader/qbittorrent/qbittorrent.go new file mode 100644 index 00000000..85a04840 --- /dev/null +++ b/pkg/downloader/qbittorrent/qbittorrent.go @@ -0,0 +1,395 @@ +package qbittorrent + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "mime/multipart" + "net/http" + "net/http/cookiejar" + "net/url" + "path/filepath" + "strings" + + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +const ( + apiPrefix = "/api/v2" + successResponse = "Ok." 
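// Minimal usage sketch for the downloader.Downloader interface defined in
// pkg/downloader/downloader.go above; not part of the diff. The concrete
// implementation (aria2, qBittorrent, or the slave proxy) is taken as given,
// and the polling interval is an arbitrary choice.
package example

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
)

// waitForDownload creates a task and polls it until it completes or fails.
func waitForDownload(ctx context.Context, d downloader.Downloader, url string) error {
	handle, err := d.CreateTask(ctx, url, map[string]interface{}{})
	if err != nil {
		return fmt.Errorf("create task: %w", err)
	}

	for {
		status, err := d.Info(ctx, handle)
		if errors.Is(err, downloader.ErrTaskNotFount) {
			// The task was removed on the downloader side.
			return err
		}
		if err != nil {
			return err
		}

		// aria2 replaces the handle once a magnet link resolves into a real
		// torrent task; keep following the new handle.
		if status.FollowedBy != nil {
			handle = status.FollowedBy
			continue
		}

		switch status.State {
		case downloader.StatusCompleted, downloader.StatusSeeding:
			return nil
		case downloader.StatusError:
			return fmt.Errorf("download %q failed", status.Name)
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}
	}
}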
+ crTagPrefix = "cr-" + + downloadPrioritySkip = 0 + downloadPriorityDownload = 1 +) + +var ( + supportDownloadOptions = map[string]bool{ + "cookie": true, + "skip_checking": true, + "root_folder": true, + "rename": true, + "upLimit": true, + "dlLimit": true, + "ratioLimit": true, + "seedingTimeLimit": true, + "autoTMM": true, + "sequentialDownload": true, + "firstLastPiecePrio": true, + } +) + +type qbittorrentClient struct { + c request.Client + settings setting.Provider + l logging.Logger + options *types.QBittorrentSetting +} + +func NewClient(l logging.Logger, c request.Client, setting setting.Provider, options *types.QBittorrentSetting) (downloader.Downloader, error) { + jar, err := cookiejar.New(nil) + if err != nil { + return nil, err + } + + server, err := url.Parse(options.Server) + if err != nil { + return nil, fmt.Errorf("invalid qbittorrent server URL: %w", err) + } + + base, _ := url.Parse(apiPrefix) + c.Apply( + request.WithCookieJar(jar), + request.WithLogger(l), + request.WithEndpoint(options.Server), + request.WithEndpoint(server.ResolveReference(base).String()), + ) + return &qbittorrentClient{c: c, options: options, l: l, settings: setting}, nil +} + +func (c *qbittorrentClient) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error { + downloadId := make([]int, 0, len(args)) + skipId := make([]int, 0, len(args)) + for _, arg := range args { + if arg.Download { + downloadId = append(downloadId, arg.Index) + } else { + skipId = append(skipId, arg.Index) + } + } + + if len(downloadId) > 0 { + if err := c.setFilePriority(ctx, handle.Hash, downloadPriorityDownload, downloadId...); err != nil { + return fmt.Errorf("failed to set file priority to download: %w", err) + } + } + + if len(skipId) > 0 { + if err := c.setFilePriority(ctx, handle.Hash, downloadPrioritySkip, skipId...); err != nil { + return fmt.Errorf("failed to set file priority to skip: %w", err) + } + } + + return nil +} + +func (c *qbittorrentClient) Cancel(ctx context.Context, handle *downloader.TaskHandle) error { + buffer := bytes.Buffer{} + formWriter := multipart.NewWriter(&buffer) + _ = formWriter.WriteField("hashes", handle.Hash) + _ = formWriter.WriteField("deleteFiles", "true") + + headers := http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + _, err := c.request(ctx, http.MethodPost, "torrents/delete", buffer.String(), &headers) + if err != nil { + return fmt.Errorf("failed to cancel task with hash %q: %w", handle.Hash, err) + } + + // Delete tags + buffer = bytes.Buffer{} + formWriter = multipart.NewWriter(&buffer) + _ = formWriter.WriteField("tags", crTagPrefix+handle.ID) + + headers = http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + _, err = c.request(ctx, http.MethodPost, "torrents/deleteTags", buffer.String(), &headers) + if err != nil { + return fmt.Errorf("failed to delete tag with id %q: %w", handle.ID, err) + } + + return nil +} + +func (c *qbittorrentClient) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) { + buffer := bytes.Buffer{} + formWriter := multipart.NewWriter(&buffer) + _ = formWriter.WriteField("tag", crTagPrefix+handle.ID) + + headers := http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + // Get task info + resp, err := c.request(ctx, http.MethodPost, "torrents/info", buffer.String(), &headers) + if err != nil { + return nil, fmt.Errorf("failed to get task info with tag %q: %w", 
crTagPrefix+handle.ID, err) + } + + var torrents []Torrent + if err := json.Unmarshal([]byte(resp), &torrents); err != nil { + return nil, fmt.Errorf("failed to unmarshal info response: %w", err) + } + + if len(torrents) == 0 { + return nil, fmt.Errorf("no torrent under tag %q: %w", crTagPrefix+handle.ID, downloader.ErrTaskNotFount) + } + + // Get file info + buffer = bytes.Buffer{} + formWriter = multipart.NewWriter(&buffer) + _ = formWriter.WriteField("hash", torrents[0].Hash) + headers = http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + resp, err = c.request(ctx, http.MethodPost, "torrents/files", buffer.String(), &headers) + if err != nil { + return nil, fmt.Errorf("failed to get torrent files with hash %q: %w", torrents[0].Hash, err) + } + + var files []File + if err := json.Unmarshal([]byte(resp), &files); err != nil { + return nil, fmt.Errorf("failed to unmarshal files response: %w", err) + } + + // Get piece status + resp, err = c.request(ctx, http.MethodPost, "torrents/pieceStates", buffer.String(), &headers) + if err != nil { + return nil, fmt.Errorf("failed to get torrent pieceStates with hash %q: %w", torrents[0].Hash, err) + } + + var pieceStates []int + if err := json.Unmarshal([]byte(resp), &pieceStates); err != nil { + return nil, fmt.Errorf("failed to unmarshal pieceStates response: %w", err) + } + + // Combining and converting all info + state := downloader.StatusDownloading + switch torrents[0].State { + case "downloading", "pausedDL", "allocating", "metaDL", "queuedDL", "stalledDL", "checkingDL", "forcedDL", "checkingResumeData", "moving": + state = downloader.StatusDownloading + case "uploading", "queuedUP", "stalledUP", "checkingUP": + state = downloader.StatusSeeding + case "pausedUP": + state = downloader.StatusCompleted + case "error", "missingFiles": + state = downloader.StatusError + default: + state = downloader.StatusUnknown + } + status := &downloader.TaskStatus{ + Name: torrents[0].Name, + Total: torrents[0].Size, + Downloaded: torrents[0].Completed, + DownloadSpeed: torrents[0].Dlspeed, + Uploaded: torrents[0].Uploaded, + UploadSpeed: torrents[0].Upspeed, + SavePath: filepath.ToSlash(torrents[0].SavePath), + State: state, + Hash: torrents[0].Hash, + Files: lo.Map(files, func(item File, index int) downloader.TaskFile { + return downloader.TaskFile{ + Index: item.Index, + Name: filepath.ToSlash(item.Name), + Size: item.Size, + Progress: item.Progress, + Selected: item.Priority > 0, + } + }), + } + + if handle.Hash != torrents[0].Hash { + handle.Hash = torrents[0].Hash + status.FollowedBy = handle + } + + // Convert piece states to hex bytes array, The highest bit corresponds to the piece at index 0. 
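// A small worked example of the piece-state packing performed just below
// (illustrative only): qBittorrent reports one integer per piece, and a
// value of 2 means the piece is fully downloaded. With
// pieceStates = [2, 0, 2, 2], the packed bitfield is the single byte
// 0b10110000 (0xB0): piece 0 occupies the highest bit, and trailing bits
// beyond the last piece are left as zero.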
+ status.NumPieces = len(pieceStates) + pieces := make([]byte, 0, len(pieceStates)/8+1) + for i := 0; i < len(pieceStates); i += 8 { + var b byte + for j := 0; j < 8; j++ { + if i+j >= len(pieceStates) { + break + } + pieceStatus := 0 + if pieceStates[i+j] == 2 { + pieceStatus = 1 + } + b |= byte(pieceStatus) << uint(7-j) + } + pieces = append(pieces, b) + } + status.Pieces = pieces + + return status, nil +} + +func (c *qbittorrentClient) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) { + guid, _ := uuid.NewV4() + + // Generate a unique path for the task + base := util.RelativePath(c.options.TempPath) + if c.options.TempPath == "" { + base = util.DataPath(c.settings.TempPath(ctx)) + } + path := filepath.Join( + base, + "qbittorrent", + guid.String(), + ) + c.l.Info("Creating QBitTorrent task with url %q saving to %q...", url, path) + + buffer := bytes.Buffer{} + formWriter := multipart.NewWriter(&buffer) + _ = formWriter.WriteField("urls", url) + _ = formWriter.WriteField("savepath", path) + _ = formWriter.WriteField("tags", crTagPrefix+guid.String()) + + // Apply global options + for k, v := range c.options.Options { + if _, ok := supportDownloadOptions[k]; ok { + _ = formWriter.WriteField(k, fmt.Sprintf("%s", v)) + } + } + + // Apply group options + for k, v := range options { + if _, ok := supportDownloadOptions[k]; ok { + _ = formWriter.WriteField(k, fmt.Sprintf("%s", v)) + } + } + + // Send request + headers := http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + resp, err := c.request(ctx, http.MethodPost, "torrents/add", buffer.String(), &headers) + if err != nil { + return nil, fmt.Errorf("create task qbittorrent failed: %w", err) + } + + if resp != successResponse { + return nil, fmt.Errorf("create task qbittorrent failed: %s", resp) + } + + return &downloader.TaskHandle{ + ID: guid.String(), + }, nil +} + +func (c *qbittorrentClient) setFilePriority(ctx context.Context, hash string, priority int, id ...int) error { + buffer := bytes.Buffer{} + formWriter := multipart.NewWriter(&buffer) + _ = formWriter.WriteField("hash", hash) + _ = formWriter.WriteField("id", strings.Join( + lo.Map(id, func(item int, index int) string { + return fmt.Sprintf("%d", item) + }), "|")) + _ = formWriter.WriteField("priority", fmt.Sprintf("%d", priority)) + + headers := http.Header{ + "Content-Type": []string{formWriter.FormDataContentType()}, + } + + _, err := c.request(ctx, http.MethodPost, "torrents/filePrio", buffer.String(), &headers) + if err != nil { + return fmt.Errorf("failed to set file priority: %w", err) + } + + return nil +} + +func (c *qbittorrentClient) Test(ctx context.Context) (string, error) { + res, err := c.request(ctx, http.MethodGet, "app/version", "", nil) + if err != nil { + return "", fmt.Errorf("test qbittorrent failed: %w", err) + } + + return res, nil +} + +func (c *qbittorrentClient) login(ctx context.Context) error { + form := url.Values{} + form.Add("username", c.options.User) + form.Add("password", c.options.Password) + res, err := c.c.Request(http.MethodPost, "auth/login", + strings.NewReader(form.Encode()), + request.WithContext(ctx), + request.WithHeader(http.Header{ + "Content-Type": []string{"application/x-www-form-urlencoded"}, + }), + ).CheckHTTPResponse(http.StatusOK).GetResponse() + if err != nil { + return fmt.Errorf("login failed with unexpected status code: %w", err) + } + + if res != successResponse { + return fmt.Errorf("login failed with response: %s, possibly 
inccorrect credential is provided", res) + } + + return nil +} + +func (c *qbittorrentClient) request(ctx context.Context, method, path string, body string, headers *http.Header) (string, error) { + opts := []request.Option{ + request.WithContext(ctx), + } + + if headers != nil { + opts = append(opts, request.WithHeader(*headers)) + } + + res := c.c.Request(method, path, strings.NewReader(body), opts...) + + if res.Err != nil { + return "", fmt.Errorf("send request failed: %w", res.Err) + } + + switch res.Response.StatusCode { + case http.StatusForbidden: + c.l.Info("QBittorrent cookie expired, sending login request...") + if err := c.login(ctx); err != nil { + return "", fmt.Errorf("login failed: %w", err) + } + + return c.request(ctx, method, path, body, headers) + + case http.StatusOK: + respContent, err := res.GetResponse() + if err != nil { + return "", fmt.Errorf("failed reading response: %w", err) + } + + return respContent, nil + case http.StatusUnsupportedMediaType: + return "", fmt.Errorf("invalid torrent file") + default: + content, _ := res.GetResponse() + return "", fmt.Errorf("unexpected status code: %d, content: %s", res.Response.StatusCode, content) + } +} diff --git a/pkg/downloader/qbittorrent/types.go b/pkg/downloader/qbittorrent/types.go new file mode 100644 index 00000000..319fbec4 --- /dev/null +++ b/pkg/downloader/qbittorrent/types.go @@ -0,0 +1,64 @@ +package qbittorrent + +type Torrent struct { + AddedOn int64 `json:"added_on"` + AmountLeft int64 `json:"amount_left"` + AutoTmm bool `json:"auto_tmm"` + Availability float64 `json:"availability"` + Category string `json:"category"` + Completed int64 `json:"completed"` + CompletionOn int64 `json:"completion_on"` + ContentPath string `json:"content_path"` + DlLimit int `json:"dl_limit"` + Dlspeed int64 `json:"dlspeed"` + DownloadPath string `json:"download_path"` + Downloaded int64 `json:"downloaded"` + DownloadedSession int `json:"downloaded_session"` + Eta int `json:"eta"` + FLPiecePrio bool `json:"f_l_piece_prio"` + ForceStart bool `json:"force_start"` + Hash string `json:"hash"` + InfohashV1 string `json:"infohash_v1"` + InfohashV2 string `json:"infohash_v2"` + LastActivity int `json:"last_activity"` + MagnetUri string `json:"magnet_uri"` + MaxRatio float64 `json:"max_ratio"` + MaxSeedingTime int `json:"max_seeding_time"` + Name string `json:"name"` + NumComplete int `json:"num_complete"` + NumIncomplete int `json:"num_incomplete"` + NumLeechs int `json:"num_leechs"` + NumSeeds int `json:"num_seeds"` + Priority int `json:"priority"` + Progress float64 `json:"progress"` + Ratio float64 `json:"ratio"` + RatioLimit float64 `json:"ratio_limit"` + SavePath string `json:"save_path"` + SeedingTime int `json:"seeding_time"` + SeedingTimeLimit int `json:"seeding_time_limit"` + SeenComplete int `json:"seen_complete"` + SeqDl bool `json:"seq_dl"` + Size int64 `json:"size"` + State string `json:"state"` + SuperSeeding bool `json:"super_seeding"` + Tags string `json:"tags"` + TimeActive int `json:"time_active"` + TotalSize int64 `json:"total_size"` + Tracker string `json:"tracker"` + TrackersCount int `json:"trackers_count"` + UpLimit int `json:"up_limit"` + Uploaded int64 `json:"uploaded"` + UploadedSession int `json:"uploaded_session"` + Upspeed int64 `json:"upspeed"` +} + +type File struct { + Index int `json:"index"` + IsSeed bool `json:"is_seed"` + Name string `json:"name"` + PieceRange []int `json:"piece_range"` + Priority int `json:"priority"` + Progress float64 `json:"progress"` + Size int64 `json:"size"` + Availability 
float64 `json:"availability"` +} diff --git a/pkg/downloader/slave/slave.go b/pkg/downloader/slave/slave.go new file mode 100644 index 00000000..f573e38f --- /dev/null +++ b/pkg/downloader/slave/slave.go @@ -0,0 +1,258 @@ +package slave + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudflare/cfssl/scan/crypto/sha1" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" +) + +type slaveDownloader struct { + client request.Client + nodeSetting *types.NodeSetting + nodeSettingHash string +} + +// NewSlaveDownloader creates a new slave downloader +func NewSlaveDownloader(client request.Client, nodeSetting *types.NodeSetting) downloader.Downloader { + nodeSettingJson, err := json.Marshal(nodeSetting) + if err != nil { + nodeSettingJson = []byte{} + } + + return &slaveDownloader{ + client: client, + nodeSetting: nodeSetting, + nodeSettingHash: fmt.Sprintf("%x", sha1.Sum(nodeSettingJson)), + } +} + +func (s *slaveDownloader) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) { + reqBody, err := json.Marshal(&CreateSlaveDownload{ + NodeSetting: s.nodeSetting, + Url: url, + Options: options, + NodeSettingHash: s.nodeSettingHash, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := s.client.Request( + "POST", + constants.APIPrefixSlave+"/download/task", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + // 处理列取结果 + if resp.Code != 0 { + return nil, serializer.NewErrorFromResponse(resp) + } + + var taskHandle *downloader.TaskHandle + if resp.GobDecode(&taskHandle); taskHandle != nil { + return taskHandle, nil + } + + return nil, fmt.Errorf("unexpected response data: %v", resp.Data) +} + +func (s *slaveDownloader) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) { + reqBody, err := json.Marshal(&GetSlaveDownload{ + NodeSetting: s.nodeSetting, + Handle: handle, + NodeSettingHash: s.nodeSettingHash, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := s.client.Request( + "POST", + constants.APIPrefixSlave+"/download/status", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + // 处理列取结果 + if resp.Code != 0 { + err = serializer.NewErrorFromResponse(resp) + if strings.Contains(err.Error(), downloader.ErrTaskNotFount.Error()) { + return nil, fmt.Errorf("%s (%w)", err.Error(), downloader.ErrTaskNotFount) + } + return nil, err + } + + var taskStatus *downloader.TaskStatus + if resp.GobDecode(&taskStatus); taskStatus != nil { + return taskStatus, nil + } + + return nil, fmt.Errorf("unexpected response data: %v", resp.Data) +} + +func (s *slaveDownloader) Cancel(ctx context.Context, handle *downloader.TaskHandle) error { + reqBody, err := json.Marshal(&CancelSlaveDownload{ + NodeSetting: s.nodeSetting, + Handle: handle, + NodeSettingHash: s.nodeSettingHash, + }) + if err != nil { + return fmt.Errorf("failed to marshal 
request body: %w", err) + } + + resp, err := s.client.Request( + "POST", + constants.APIPrefixSlave+"/download/cancel", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (s *slaveDownloader) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error { + reqBody, err := json.Marshal(&SetSlaveFilesToDownload{ + NodeSetting: s.nodeSetting, + Handle: handle, + NodeSettingHash: s.nodeSettingHash, + Args: args, + }) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := s.client.Request( + "POST", + constants.APIPrefixSlave+"/download/select", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + // 处理列取结果 + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (s *slaveDownloader) Test(ctx context.Context) (string, error) { + reqBody, err := json.Marshal(&TestSlaveDownload{ + NodeSetting: s.nodeSetting, + NodeSettingHash: s.nodeSettingHash, + }) + if err != nil { + return "", fmt.Errorf("failed to marshal request body: %w", err) + } + + resp, err := s.client.Request( + "POST", + constants.APIPrefixSlave+"/download/test", + bytes.NewReader(reqBody), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return "", err + } + + if resp.Code != 0 { + return "", serializer.NewErrorFromResponse(resp) + } + + return resp.Data.(string), nil +} + +// Slave remote download related +type ( + // Request body for creating tasks on slave node + CreateSlaveDownload struct { + NodeSetting *types.NodeSetting `json:"node_setting" binding:"required"` + NodeSettingHash string `json:"node_setting_hash" binding:"required"` + Url string `json:"url" binding:"required"` + Options map[string]interface{} `json:"options"` + } + // Request body for get download task info from slave node + GetSlaveDownload struct { + Handle *downloader.TaskHandle `json:"handle" binding:"required"` + NodeSetting *types.NodeSetting `json:"node_setting" binding:"required"` + NodeSettingHash string `json:"node_setting_hash" binding:"required"` + } + + // Request body for cancel download task on slave node + CancelSlaveDownload struct { + Handle *downloader.TaskHandle `json:"handle" binding:"required"` + NodeSetting *types.NodeSetting `json:"node_setting" binding:"required"` + NodeSettingHash string `json:"node_setting_hash" binding:"required"` + } + + // Request body for selecting files to download on slave node + SetSlaveFilesToDownload struct { + Handle *downloader.TaskHandle `json:"handle" binding:"required"` + Args []*downloader.SetFileToDownloadArgs `json:"args" binding:"required"` + NodeSettingHash string `json:"node_setting_hash" binding:"required"` + NodeSetting *types.NodeSetting `json:"node_setting" binding:"required"` + } + + TestSlaveDownload struct { + NodeSetting *types.NodeSetting `json:"node_setting" binding:"required"` + NodeSettingHash string `json:"node_setting_hash" binding:"required"` + } +) + +// GetNodeSetting implements SlaveNodeSettingGetter interface +func (d *CreateSlaveDownload) GetNodeSetting() 
(*types.NodeSetting, string) { + return d.NodeSetting, d.NodeSettingHash +} + +// GetNodeSetting implements SlaveNodeSettingGetter interface +func (d *GetSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) { + return d.NodeSetting, d.NodeSettingHash +} + +// GetNodeSetting implements SlaveNodeSettingGetter interface +func (d *CancelSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) { + return d.NodeSetting, d.NodeSettingHash +} + +// GetNodeSetting implements SlaveNodeSettingGetter interface +func (d *SetSlaveFilesToDownload) GetNodeSetting() (*types.NodeSetting, string) { + return d.NodeSetting, d.NodeSettingHash +} + +// GetNodeSetting implements SlaveNodeSettingGetter interface +func (d *TestSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) { + return d.NodeSetting, d.NodeSettingHash +} diff --git a/pkg/email/init.go b/pkg/email/init.go deleted file mode 100644 index fe83fe3b..00000000 --- a/pkg/email/init.go +++ /dev/null @@ -1,52 +0,0 @@ -package email - -import ( - "sync" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// Client 默认的邮件发送客户端 -var Client Driver - -// Lock 读写锁 -var Lock sync.RWMutex - -// Init 初始化 -func Init() { - util.Log().Debug("Initializing email sending queue...") - Lock.Lock() - defer Lock.Unlock() - - if Client != nil { - Client.Close() - } - - // 读取SMTP设置 - options := model.GetSettingByNames( - "fromName", - "fromAdress", - "smtpHost", - "replyTo", - "smtpUser", - "smtpPass", - "smtpEncryption", - ) - port := model.GetIntSetting("smtpPort", 25) - keepAlive := model.GetIntSetting("mail_keepalive", 30) - - client := NewSMTPClient(SMTPConfig{ - Name: options["fromName"], - Address: options["fromAdress"], - ReplyTo: options["replyTo"], - Host: options["smtpHost"], - Port: port, - User: options["smtpUser"], - Password: options["smtpPass"], - Keepalive: keepAlive, - Encryption: model.IsTrueVal(options["smtpEncryption"]), - }) - - Client = client -} diff --git a/pkg/email/mail.go b/pkg/email/mail.go index fbcbd683..63d6089b 100644 --- a/pkg/email/mail.go +++ b/pkg/email/mail.go @@ -1,8 +1,8 @@ package email import ( + "context" "errors" - "strings" ) // Driver 邮件发送驱动 @@ -10,7 +10,7 @@ type Driver interface { // Close 关闭驱动 Close() // Send 发送邮件 - Send(to, title, body string) error + Send(ctx context.Context, to, title, body string) error } var ( @@ -19,20 +19,3 @@ var ( // ErrNoActiveDriver 无可用邮件发送服务 ErrNoActiveDriver = errors.New("no avaliable email provider") ) - -// Send 发送邮件 -func Send(to, title, body string) error { - // 忽略通过QQ登录的邮箱 - if strings.HasSuffix(to, "@login.qq.com") { - return nil - } - - Lock.RLock() - defer Lock.RUnlock() - - if Client == nil { - return ErrNoActiveDriver - } - - return Client.Send(to, title, body) -} diff --git a/pkg/email/smtp.go b/pkg/email/smtp.go index 79f07855..1ea8926a 100644 --- a/pkg/email/smtp.go +++ b/pkg/email/smtp.go @@ -1,19 +1,27 @@ package email import ( + "context" "fmt" - "github.com/google/uuid" + "strings" "time" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" "github.com/go-mail/mail" + "github.com/gofrs/uuid" ) -// SMTP SMTP协议发送邮件 -type SMTP struct { +// SMTPPool SMTP协议发送邮件 +type SMTPPool struct { + // Deprecated Config SMTPConfig - ch chan *mail.Message + + config *setting.SMTP + ch chan *message chOpen bool + l logging.Logger } // SMTPConfig SMTP发送配置 @@ -26,14 +34,34 @@ type SMTPConfig 
struct { User string // 用户名 Password string // 密码 Encryption bool // 是否启用加密 - Keepalive int // SMTP 连接保留时长 + Keepalive int // SMTPPool 连接保留时长 +} + +type message struct { + msg *mail.Message + cid string + userID int +} + +// NewSMTPPool initializes a new SMTP based email sending queue. +func NewSMTPPool(config setting.Provider, logger logging.Logger) *SMTPPool { + client := &SMTPPool{ + config: config.SMTP(context.Background()), + ch: make(chan *message, 30), + chOpen: false, + l: logger, + } + + client.Init() + return client } // NewSMTPClient 新建SMTP发送队列 -func NewSMTPClient(config SMTPConfig) *SMTP { - client := &SMTP{ +// Deprecated +func NewSMTPClient(config SMTPConfig) *SMTPPool { + client := &SMTPPool{ Config: config, - ch: make(chan *mail.Message, 30), + ch: make(chan *message, 30), chOpen: false, } @@ -43,46 +71,57 @@ func NewSMTPClient(config SMTPConfig) *SMTP { } // Send 发送邮件 -func (client *SMTP) Send(to, title, body string) error { +func (client *SMTPPool) Send(ctx context.Context, to, title, body string) error { if !client.chOpen { - return ErrChanNotOpen + return fmt.Errorf("SMTP pool is closed") + } + + // 忽略通过QQ登录的邮箱 + if strings.HasSuffix(to, "@login.qq.com") { + return nil } + m := mail.NewMessage() - m.SetAddressHeader("From", client.Config.Address, client.Config.Name) - m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name) + m.SetAddressHeader("From", client.config.From, client.config.FromName) + m.SetAddressHeader("Reply-To", client.config.ReplyTo, client.config.FromName) m.SetHeader("To", to) m.SetHeader("Subject", title) - m.SetHeader("Message-ID", fmt.Sprintf("<%s@%s>", uuid.NewString(), "cloudreve")) + m.SetHeader("Message-ID", fmt.Sprintf("<%s@%s>", uuid.Must(uuid.NewV4()).String(), "cloudreve")) m.SetBody("text/html", body) - client.ch <- m + client.ch <- &message{ + msg: m, + cid: logging.CorrelationID(ctx).String(), + userID: inventory.UserIDFromContext(ctx), + } return nil } // Close 关闭发送队列 -func (client *SMTP) Close() { +func (client *SMTPPool) Close() { if client.ch != nil { close(client.ch) } } // Init 初始化发送队列 -func (client *SMTP) Init() { +func (client *SMTPPool) Init() { go func() { + client.l.Info("Initializing and starting SMTP email pool...") defer func() { if err := recover(); err != nil { client.chOpen = false - util.Log().Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err) + client.l.Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err) time.Sleep(time.Duration(10) * time.Second) client.Init() } }() - d := mail.NewDialer(client.Config.Host, client.Config.Port, client.Config.User, client.Config.Password) - d.Timeout = time.Duration(client.Config.Keepalive+5) * time.Second + d := mail.NewDialer(client.config.Host, client.config.Port, client.config.User, client.config.Password) + d.Timeout = time.Duration(client.config.Keepalive+5) * time.Second client.chOpen = true // 是否启用 SSL d.SSL = false - if client.Config.Encryption { + if client.config.ForceEncryption { d.SSL = true } d.StartTLSPolicy = mail.OpportunisticStartTLS @@ -94,26 +133,29 @@ func (client *SMTP) Init() { select { case m, ok := <-client.ch: if !ok { - util.Log().Debug("Email queue closing...") + client.l.Info("Email queue closing...") client.chOpen = false return } + if !open { if s, err = d.Dial(); err != nil { panic(err) } open = true } - if err := mail.Send(s, m); err != nil { - util.Log().Warning("Failed to send email: %s", err) + + l := client.l.CopyWithPrefix(fmt.Sprintf("[Cid: %s]", m.cid)) + if err := 
mail.Send(s, m.msg); err != nil { + l.Warning("Failed to send email: %s, Cid=%s", err, m.cid) } else { - util.Log().Debug("Email sent.") + l.Info("Email sent to %q, title: %q.", m.msg.GetHeader("To"), m.msg.GetHeader("Subject")) } // 长时间没有新邮件,则关闭SMTP连接 - case <-time.After(time.Duration(client.Config.Keepalive) * time.Second): + case <-time.After(time.Duration(client.config.Keepalive) * time.Second): if open { if err := s.Close(); err != nil { - util.Log().Warning("Failed to close SMTP connection: %s", err) + client.l.Warning("Failed to close SMTP connection: %s", err) } open = false } diff --git a/pkg/email/template.go b/pkg/email/template.go index cb9cb3a1..d732eec0 100644 --- a/pkg/email/template.go +++ b/pkg/email/template.go @@ -1,36 +1,125 @@ package email import ( + "context" "fmt" + "html/template" + "net/url" + "strings" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" ) -// NewActivationEmail 新建激活邮件 -func NewActivationEmail(userName, activateURL string) (string, string) { - options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template") - replace := map[string]string{ - "{siteTitle}": options["siteName"], - "{userName}": userName, - "{activationUrl}": activateURL, - "{siteUrl}": options["siteURL"], - "{siteSecTitle}": options["siteTitle"], - } - return fmt.Sprintf("【%s】注册激活", options["siteName"]), - util.Replace(replace, options["mail_activation_template"]) +type CommonContext struct { + SiteBasic *setting.SiteBasic + Logo *setting.Logo + SiteUrl string } -// NewResetEmail 新建重设密码邮件 -func NewResetEmail(userName, resetURL string) (string, string) { - options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_reset_pwd_template") - replace := map[string]string{ - "{siteTitle}": options["siteName"], - "{userName}": userName, - "{resetUrl}": resetURL, - "{siteUrl}": options["siteURL"], - "{siteSecTitle}": options["siteTitle"], - } - return fmt.Sprintf("【%s】密码重置", options["siteName"]), - util.Replace(replace, options["mail_reset_pwd_template"]) +// ResetContext used for variables in reset email +type ResetContext struct { + *CommonContext + User *ent.User + Url string +} + +// NewResetEmail generates reset email from template +func NewResetEmail(ctx context.Context, settings setting.Provider, user *ent.User, url string) (string, string, error) { + templates := settings.ResetEmailTemplate(ctx) + if len(templates) == 0 { + return "", "", fmt.Errorf("reset email template not configured") + } + + selected := selectTemplate(templates, user) + resetCtx := ResetContext{ + CommonContext: commonContext(ctx, settings), + User: user, + Url: url, + } + + tmpl, err := template.New("reset").Parse(selected.Body) + if err != nil { + return "", "", fmt.Errorf("failed to parse email template: %w", err) + } + + var res strings.Builder + err = tmpl.Execute(&res, resetCtx) + if err != nil { + return "", "", fmt.Errorf("failed to execute email template: %w", err) + } + + return fmt.Sprintf("[%s] %s", resetCtx.SiteBasic.Name, selected.Title), res.String(), nil +} + +// ActivationContext used for variables in activation email +type ActivationContext struct { + *CommonContext + User *ent.User + Url string +} + +// NewActivationEmail generates activation email from template +func NewActivationEmail(ctx context.Context, settings setting.Provider, user *ent.User, url string) (string, string, error) { + templates := 
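// Illustrative only: a minimal reset-mail template body compatible with
// NewResetEmail above. Bodies are parsed with html/template, so the exported
// fields of ResetContext are available as template variables; the specific
// *ent.User field referenced here (.User.Email) is an assumption about the
// generated ent schema.
const sampleResetBody = `
<p>Hello {{.User.Email}},</p>
<p>A password reset was requested for your {{.SiteBasic.Name}} account.</p>
<p><a href="{{.Url}}">Reset your password</a></p>
<p><img src="{{.Logo.Normal}}" alt="{{.SiteBasic.Name}}" /></p>
`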
settings.ActivationEmailTemplate(ctx) + if len(templates) == 0 { + return "", "", fmt.Errorf("activation email template not configured") + } + + selected := selectTemplate(templates, user) + activationCtx := ActivationContext{ + CommonContext: commonContext(ctx, settings), + User: user, + Url: url, + } + + tmpl, err := template.New("activation").Parse(selected.Body) + if err != nil { + return "", "", fmt.Errorf("failed to parse email template: %w", err) + } + + var res strings.Builder + err = tmpl.Execute(&res, activationCtx) + if err != nil { + return "", "", fmt.Errorf("failed to execute email template: %w", err) + } + + return fmt.Sprintf("[%s] %s", activationCtx.SiteBasic.Name, selected.Title), res.String(), nil +} + +func commonContext(ctx context.Context, settings setting.Provider) *CommonContext { + logo := settings.Logo(ctx) + siteUrl := settings.SiteURL(ctx) + res := &CommonContext{ + SiteBasic: settings.SiteBasic(ctx), + Logo: settings.Logo(ctx), + SiteUrl: siteUrl.String(), + } + + // Add site url if logo is not an url + if !strings.HasPrefix(logo.Light, "http") { + logoPath, _ := url.Parse(logo.Light) + res.Logo.Light = siteUrl.ResolveReference(logoPath).String() + } + + if !strings.HasPrefix(logo.Normal, "http") { + logoPath, _ := url.Parse(logo.Normal) + res.Logo.Normal = siteUrl.ResolveReference(logoPath).String() + } + + return res +} + +func selectTemplate(templates []setting.EmailTemplate, u *ent.User) setting.EmailTemplate { + selected := templates[0] + if u != nil { + for _, t := range templates { + if strings.EqualFold(t.Language, u.Settings.Language) { + selected = t + break + } + } + } + + return selected } diff --git a/pkg/filesystem/chunk/backoff/backoff.go b/pkg/filemanager/chunk/backoff/backoff.go similarity index 97% rename from pkg/filesystem/chunk/backoff/backoff.go rename to pkg/filemanager/chunk/backoff/backoff.go index 95cb1b5f..e4aab0fc 100644 --- a/pkg/filesystem/chunk/backoff/backoff.go +++ b/pkg/filemanager/chunk/backoff/backoff.go @@ -3,7 +3,7 @@ package backoff import ( "errors" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "net/http" "strconv" "time" diff --git a/pkg/filesystem/chunk/chunk.go b/pkg/filemanager/chunk/chunk.go similarity index 72% rename from pkg/filesystem/chunk/chunk.go rename to pkg/filemanager/chunk/chunk.go index cf790f68..174d785b 100644 --- a/pkg/filesystem/chunk/chunk.go +++ b/pkg/filemanager/chunk/chunk.go @@ -3,10 +3,11 @@ package chunk import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "io" "os" ) @@ -18,36 +19,38 @@ type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error // ChunkGroup manage groups of chunks type ChunkGroup struct { - file fsctx.FileHeader - chunkSize uint64 + file *fs.UploadRequest + chunkSize int64 backoff backoff.Backoff enableRetryBuffer bool + l logging.Logger - fileInfo *fsctx.UploadTaskInfo currentIndex int - chunkNum uint64 + chunkNum int64 bufferTemp *os.File + tempPath string } -func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer 
bool) *ChunkGroup { +func NewChunkGroup(file *fs.UploadRequest, chunkSize int64, backoff backoff.Backoff, useBuffer bool, l logging.Logger, tempPath string) *ChunkGroup { c := &ChunkGroup{ file: file, chunkSize: chunkSize, backoff: backoff, - fileInfo: file.Info(), currentIndex: -1, enableRetryBuffer: useBuffer, + l: l, + tempPath: tempPath, } if c.chunkSize == 0 { - c.chunkSize = c.fileInfo.Size + c.chunkSize = c.file.Props.Size } - if c.fileInfo.Size == 0 { + if c.file.Props.Size == 0 { c.chunkNum = 1 } else { - c.chunkNum = c.fileInfo.Size / c.chunkSize - if c.fileInfo.Size%c.chunkSize != 0 { + c.chunkNum = c.file.Props.Size / c.chunkSize + if c.file.Props.Size%c.chunkSize != 0 { c.chunkNum++ } } @@ -71,7 +74,7 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { // If useBuffer is enabled, tee the reader to a temp file if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() { - c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern) + c.bufferTemp, _ = os.CreateTemp(util.DataPath(c.tempPath), bufferTempPattern) reader = io.TeeReader(reader, c.bufferTemp) } @@ -90,7 +93,7 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { return fmt.Errorf("failed to seek temp file back to chunk start: %w", err) } - util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name()) + c.l.Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name()) reader = io.NopCloser(c.bufferTemp) } } @@ -108,25 +111,25 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { } } - util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err) + c.l.Debug("Retrying chunk %d, last error: %s", c.currentIndex, err) return c.Process(processor) } return err } - util.Log().Debug("Chunk %d processed", c.currentIndex) + c.l.Debug("Chunk %d processed", c.currentIndex) return nil } // Start returns the byte index of current chunk func (c *ChunkGroup) Start() int64 { - return int64(uint64(c.Index()) * c.chunkSize) + return int64(int64(c.Index()) * c.chunkSize) } // Total returns the total length func (c *ChunkGroup) Total() int64 { - return int64(c.fileInfo.Size) + return int64(c.file.Props.Size) } // Num returns the total chunk number @@ -155,7 +158,7 @@ func (c *ChunkGroup) Next() bool { func (c *ChunkGroup) Length() int64 { contentLength := c.chunkSize if c.Index() == int(c.chunkNum-1) { - contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1) + contentLength = c.file.Props.Size - c.chunkSize*(c.chunkNum-1) } return int64(contentLength) diff --git a/pkg/filemanager/driver/cos/cos.go b/pkg/filemanager/driver/cos/cos.go new file mode 100644 index 00000000..186b6f30 --- /dev/null +++ b/pkg/filemanager/driver/cos/cos.go @@ -0,0 +1,588 @@ +package cos + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + 
"github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/google/go-querystring/query" + "github.com/samber/lo" + cossdk "github.com/tencentyun/cos-go-sdk-v5" +) + +// UploadPolicy 腾讯云COS上传策略 +type UploadPolicy struct { + Expiration string `json:"expiration"` + Conditions []interface{} `json:"conditions"` +} + +// MetaData 文件元信息 +type MetaData struct { + Size uint64 + CallbackKey string + CallbackURL string +} + +type urlOption struct { + Speed int64 `url:"x-cos-traffic-limit,omitempty"` + ContentDescription string `url:"response-content-disposition,omitempty"` + Exif *string `url:"exif,omitempty"` + CiProcess string `url:"ci-process,omitempty"` +} + +type ( + CosParts struct { + ETag string + PartNumber int + } +) + +// Driver 腾讯云COS适配器模板 +type Driver struct { + policy *ent.StoragePolicy + client *cossdk.Client + settings setting.Provider + config conf.ConfigProvider + httpClient request.Client + l logging.Logger + mime mime.MimeDetector + + chunkSize int64 +} + +const ( + // MultiPartUploadThreshold 服务端使用分片上传的阈值 + MultiPartUploadThreshold int64 = 5 * (1 << 30) // 5GB + + maxDeleteBatch = 1000 + chunkRetrySleep = time.Duration(5) * time.Second + overwriteOptionHeader = "x-cos-forbid-overwrite" + partNumberParam = "partNumber" + uploadIdParam = "uploadId" + contentTypeHeader = "Content-Type" + contentLengthHeader = "Content-Length" +) + +var ( + features = &boolset.BooleanSet{} +) + +func init() { + cossdk.SetNeedSignHeaders("host", false) + cossdk.SetNeedSignHeaders("origin", false) + boolset.Sets(map[driver.HandlerCapability]bool{ + driver.HandlerCapabilityUploadSentinelRequired: true, + }, features) +} + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 25 << 20 // 25 MB + } + + driver := &Driver{ + policy: policy, + settings: settings, + chunkSize: chunkSize, + config: config, + l: l, + mime: mime, + httpClient: request.NewClient(config, request.WithLogger(l)), + } + + u, err := url.Parse(policy.Server) + if err != nil { + return nil, fmt.Errorf("failed to parse COS bucket server url: %w", err) + } + driver.client = cossdk.NewClient(&cossdk.BaseURL{BucketURL: u}, &http.Client{ + Transport: &cossdk.AuthorizationTransport{ + SecretID: policy.AccessKey, + SecretKey: policy.SecretKey, + }, + }) + + return driver, nil +} + +// +//// List 列出COS文件 +//func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// // 初始化列目录参数 +// opt := &cossdk.BucketGetOptions{ +// Prefix: strings.TrimPrefix(base, "/"), +// EncodingType: "", +// MaxKeys: 1000, +// } +// // 是否为递归列出 +// if !recursive { +// opt.Delimiter = "/" +// } +// // 手动补齐结尾的slash +// if opt.Prefix != "" { +// opt.Prefix += "/" +// } +// +// var ( +// marker string +// objects []cossdk.Object +// commons []string +// ) +// +// for { +// res, _, err := handler.client.Bucket.Get(ctx, opt) +// if err != nil { +// return nil, err +// } +// objects = append(objects, res.Contents...) +// commons = append(commons, res.CommonPrefixes...) 
+// // 如果本次未列取完,则继续使用marker获取结果 +// marker = res.NextMarker +// // marker 为空时结果列取完毕,跳出 +// if marker == "" { +// break +// } +// } +// +// // 处理列取结果 +// res := make([]response.Object, 0, len(objects)+len(commons)) +// // 处理目录 +// for _, object := range commons { +// rel, err := filepath.Rel(opt.Prefix, object) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object), +// RelativePath: filepath.ToSlash(rel), +// Size: 0, +// IsDir: true, +// LastModify: time.Now(), +// }) +// } +// // 处理文件 +// for _, object := range objects { +// rel, err := filepath.Rel(opt.Prefix, object.Key) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object.Key), +// Source: object.Key, +// RelativePath: filepath.ToSlash(rel), +// Size: uint64(object.Size), +// IsDir: false, +// LastModify: time.Now(), +// }) +// } +// +// return res, nil +// +//} + +// CORS 创建跨域策略 +func (handler Driver) CORS() error { + _, err := handler.client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{ + Rules: []cossdk.BucketCORSRule{{ + AllowedMethods: []string{ + "GET", + "POST", + "PUT", + "DELETE", + "HEAD", + }, + AllowedOrigins: []string{"*"}, + AllowedHeaders: []string{"*"}, + MaxAgeSeconds: 3600, + ExposeHeaders: []string{"ETag"}, + }}, + }) + + return err +} + +// Get 获取文件 +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + opt := &cossdk.ObjectPutHeaderOptions{ + ContentType: mimeType, + XOptionHeader: &http.Header{ + overwriteOptionHeader: []string{fmt.Sprintf("%t", overwrite)}, + }, + } + + // 小文件直接上传 + if file.Props.Size < MultiPartUploadThreshold { + _, err := handler.client.Object.Put(ctx, file.Props.SavePath, file, &cossdk.ObjectPutOptions{ + ObjectPutHeaderOptions: opt, + }) + return err + } + + imur, _, err := handler.client.Object.InitiateMultipartUpload(ctx, file.Props.SavePath, &cossdk.InitiateMultipartUploadOptions{ + ObjectPutHeaderOptions: opt, + }) + + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{ + Max: handler.settings.ChunkRetryLimit(ctx), + Sleep: chunkRetrySleep, + }, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx)) + + parts := make([]CosParts, 0, chunks.Num()) + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + res, err := handler.client.Object.UploadPart(ctx, file.Props.SavePath, imur.UploadID, current.Index()+1, content, &cossdk.ObjectUploadPartOptions{ + ContentLength: current.Length(), + }) + if err == nil { + parts = append(parts, CosParts{ + ETag: res.Header.Get("ETag"), + PartNumber: current.Index() + 1, + }) + } + return err + } + + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + handler.cancelUpload(file.Props.SavePath, imur.UploadID) + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + _, _, err = handler.client.Object.CompleteMultipartUpload(ctx, file.Props.SavePath, imur.UploadID, &cossdk.CompleteMultipartUploadOptions{ + Parts: lo.Map(parts, func(v CosParts, i int) cossdk.Object { + return cossdk.Object{ + ETag: v.ETag, + PartNumber: v.PartNumber, 
+ } + }), + XOptionHeader: &http.Header{ + overwriteOptionHeader: []string{fmt.Sprintf("%t", overwrite)}, + }, + }) + + if err != nil { + handler.cancelUpload(file.Props.SavePath, imur.UploadID) + } + + return err +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + groups := lo.Chunk(files, maxDeleteBatch) + failed := make([]string, 0) + var lastError error + for index, group := range groups { + handler.l.Debug("Process delete group #%d: %v", index, group) + res, _, err := handler.client.Object.DeleteMulti(ctx, + &cossdk.ObjectDeleteMultiOptions{ + Objects: lo.Map(group, func(item string, index int) cossdk.Object { + return cossdk.Object{Key: item} + }), + Quiet: true, + }) + if err != nil { + lastError = err + failed = append(failed, group...) + continue + } + + for _, v := range res.Errors { + handler.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key) + failed = append(failed, v.Key) + } + } + + if len(failed) > 0 && lastError == nil { + lastError = fmt.Errorf("failed to delete files: %v", failed) + } + + return failed, lastError +} + +// Thumb 获取文件缩略图 +func (handler Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + w, h := handler.settings.ThumbSize(ctx) + thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d", w, h) + + source, err := handler.signSourceURL( + ctx, + e.Source(), + expire, + &urlOption{}, + ) + if err != nil { + return "", err + } + + thumbURL, _ := url.Parse(source) + thumbQuery := thumbURL.Query() + thumbQuery.Add(thumbParam, "") + thumbURL.RawQuery = thumbQuery.Encode() + + return thumbURL.String(), nil +} + +// Source 获取外链URL +func (handler Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + // 添加各项设置 + options := urlOption{} + if args.Speed > 0 { + if args.Speed < 819200 { + args.Speed = 819200 + } + if args.Speed > 838860800 { + args.Speed = 838860800 + } + options.Speed = args.Speed + } + if args.IsDownload { + encodedFilename := url.PathEscape(args.DisplayName) + options.ContentDescription = fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + encodedFilename, encodedFilename) + } + + return handler.signSourceURL(ctx, e.Source(), args.Expire, &options) +} + +func (handler Driver) signSourceURL(ctx context.Context, path string, expire *time.Time, options *urlOption) (string, error) { + // 公有空间不需要签名 + if !handler.policy.IsPrivate || (handler.policy.Settings.SourceAuth && handler.policy.Settings.CustomProxy) { + file, err := url.Parse(handler.policy.Server) + if err != nil { + return "", err + } + + file.Path = path + + // 非签名URL不支持设置响应header + options.ContentDescription = "" + + optionQuery, err := query.Values(*options) + if err != nil { + return "", err + } + file.RawQuery = optionQuery.Encode() + + return file.String(), nil + } + + ttl := time.Duration(0) + if expire != nil { + ttl = time.Until(*expire) + } else { + // 20 years for permanent link + ttl = time.Duration(24) * time.Hour * 365 * 20 + } + + presignedURL, err := handler.client.Object.GetPresignedURL(ctx, http.MethodGet, path, + handler.policy.AccessKey, handler.policy.SecretKey, ttl, options) + if err != nil { + return "", err + } + + return presignedURL.String(), nil +} + +// Token 获取上传策略和认证Token +func (handler Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // 生成回调地址 + siteURL := 
handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + // 在从机端创建上传会话 + uploadSession.ChunkSize = handler.chunkSize + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeCos, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + // 初始化分片上传 + opt := &cossdk.ObjectPutHeaderOptions{ + ContentType: mimeType, + XOptionHeader: &http.Header{ + overwriteOptionHeader: []string{"true"}, + }, + } + + imur, _, err := handler.client.Object.InitiateMultipartUpload(ctx, file.Props.SavePath, &cossdk.InitiateMultipartUploadOptions{ + ObjectPutHeaderOptions: opt, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize multipart upload: %w", err) + } + uploadSession.UploadID = imur.UploadID + + // 为每个分片签名上传 URL + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "") + urls := make([]string, chunks.Num()) + ttl := time.Until(uploadSession.Props.ExpireAt) + for chunks.Next() { + err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { + signedURL, err := handler.client.Object.GetPresignedURL( + ctx, + http.MethodPut, + file.Props.SavePath, + handler.policy.AccessKey, + handler.policy.SecretKey, + ttl, + &cossdk.PresignedURLOptions{ + Query: &url.Values{ + partNumberParam: []string{fmt.Sprintf("%d", c.Index()+1)}, + uploadIdParam: []string{imur.UploadID}, + }, + Header: &http.Header{ + contentTypeHeader: []string{"application/octet-stream"}, + contentLengthHeader: []string{fmt.Sprintf("%d", c.Length())}, + }, + }) + if err != nil { + return err + } + + urls[c.Index()] = signedURL.String() + return nil + }) + if err != nil { + return nil, err + } + } + + // 签名完成分片上传的URL + completeURL, err := handler.client.Object.GetPresignedURL( + ctx, + http.MethodPost, + file.Props.SavePath, + handler.policy.AccessKey, + handler.policy.SecretKey, + time.Until(uploadSession.Props.ExpireAt), + &cossdk.PresignedURLOptions{ + Query: &url.Values{ + uploadIdParam: []string{imur.UploadID}, + }, + Header: &http.Header{ + overwriteOptionHeader: []string{"true"}, + }, + }) + if err != nil { + return nil, err + } + + return &fs.UploadCredential{ + UploadID: imur.UploadID, + UploadURLs: urls, + CompleteURL: completeURL.String(), + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.chunkSize, + }, nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + _, err := handler.client.Object.AbortMultipartUpload(ctx, uploadSession.Props.SavePath, uploadSession.UploadID) + return err +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + if session.SentinelTaskID == 0 { + return nil + } + + // Make sure uploaded file size is correct + res, err := handler.client.Object.Head(ctx, session.Props.SavePath, &cossdk.ObjectHeadOptions{}) + if err != nil { + return fmt.Errorf("failed to get uploaded file size: %w", err) + } + + if res.ContentLength != session.Props.Size { + return serializer.NewError( + serializer.CodeMetaMismatch, + fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.ContentLength), + nil, + ) + } + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + mediaMetaExts := handler.policy.Settings.MediaMetaExts + if !handler.policy.Settings.NativeMediaProcessing { + mediaMetaExts = nil + } + return 
&driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: mediaMetaExts, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: handler.policy.Settings.ThumbExts, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + ThumbMaxSize: handler.policy.Settings.ThumbMaxSize, + ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts, + } +} + +// Meta 获取文件信息 +func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) { + res, err := handler.client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{}) + if err != nil { + return nil, err + } + return &MetaData{ + Size: uint64(res.ContentLength), + CallbackKey: res.Header.Get("x-cos-meta-key"), + CallbackURL: res.Header.Get("x-cos-meta-callback"), + }, nil +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + if util.ContainsString(supportedImageExt, ext) { + return handler.extractImageMeta(ctx, path) + } + + return handler.extractStreamMeta(ctx, path) +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func (handler *Driver) cancelUpload(path, uploadId string) { + if _, err := handler.client.Object.AbortMultipartUpload(context.Background(), path, uploadId); err != nil { + handler.l.Warning("failed to abort multipart upload: %s", err) + } +} diff --git a/pkg/filemanager/driver/cos/media.go b/pkg/filemanager/driver/cos/media.go new file mode 100644 index 00000000..0a894a92 --- /dev/null +++ b/pkg/filemanager/driver/cos/media.go @@ -0,0 +1,294 @@ +package cos + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/samber/lo" + "math" + "net/http" + "strconv" + "strings" + "time" +) + +const ( + mediaInfoTTL = time.Duration(10) * time.Minute + videoInfo = "videoinfo" +) + +var ( + supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff", "heic", "heif"} +) + +type ( + ImageProp struct { + Value string `json:"val"` + } + ImageInfo map[string]ImageProp + Error struct { + XMLName xml.Name `xml:"Error"` + Code string `xml:"Code"` + Message string `xml:"Message"` + RequestId string `xml:"RequestId"` + } + Video struct { + Index int `xml:"Index"` + CodecName string `xml:"CodecName"` + CodecLongName string `xml:"CodecLongName"` + CodecTimeBase string `xml:"CodecTimeBase"` + CodecTagString string `xml:"CodecTagString"` + CodecTag string `xml:"CodecTag"` + ColorPrimaries string `xml:"ColorPrimaries"` + ColorRange string `xml:"ColorRange"` + ColorTransfer string `xml:"ColorTransfer"` + Profile string `xml:"Profile"` + Width int `xml:"Width"` + Height int `xml:"Height"` + HasBFrame string `xml:"HasBFrame"` + RefFrames string `xml:"RefFrames"` + Sar string `xml:"Sar"` + Dar string `xml:"Dar"` + PixFormat string `xml:"PixFormat"` + FieldOrder string `xml:"FieldOrder"` + Level string `xml:"Level"` + Fps string `xml:"Fps"` + AvgFps string `xml:"AvgFps"` + Timebase string `xml:"Timebase"` + StartTime string `xml:"StartTime"` + Duration string `xml:"Duration"` + Bitrate string `xml:"Bitrate"` + NumFrames string `xml:"NumFrames"` + Language string `xml:"Language"` + } + Audio struct { + Index int `xml:"Index"` + CodecName string `xml:"CodecName"` + CodecLongName string `xml:"CodecLongName"` + CodecTimeBase string `xml:"CodecTimeBase"` + CodecTagString string 
`xml:"CodecTagString"` + CodecTag string `xml:"CodecTag"` + SampleFmt string `xml:"SampleFmt"` + SampleRate string `xml:"SampleRate"` + Channel string `xml:"Channel"` + ChannelLayout string `xml:"ChannelLayout"` + Timebase string `xml:"Timebase"` + StartTime string `xml:"StartTime"` + Duration string `xml:"Duration"` + Bitrate string `xml:"Bitrate"` + Language string `xml:"Language"` + } + Subtitle struct { + Index string `xml:"Index"` + Language string `xml:"Language"` + } + Response struct { + XMLName xml.Name `xml:"Response"` + MediaInfo struct { + Stream struct { + Video []Video `xml:"Video"` + Audio []Audio `xml:"Audio"` + Subtitle []Subtitle `xml:"Subtitle"` + } `xml:"Stream"` + Format struct { + NumStream string `xml:"NumStream"` + NumProgram string `xml:"NumProgram"` + FormatName string `xml:"FormatName"` + FormatLongName string `xml:"FormatLongName"` + StartTime string `xml:"StartTime"` + Duration string `xml:"Duration"` + Bitrate string `xml:"Bitrate"` + Size string `xml:"Size"` + } `xml:"Format"` + } `xml:"MediaInfo"` + } +) + +func (handler *Driver) extractStreamMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, &urlOption{CiProcess: videoInfo}) + if err != nil { + return nil, err + } + + var info Response + if err := xml.Unmarshal([]byte(resp), &info); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + streams := lo.Map(info.MediaInfo.Stream.Video, func(stream Video, index int) mediameta.Stream { + return mediameta.Stream{ + Index: stream.Index, + CodecName: stream.CodecName, + CodecLongName: stream.CodecLongName, + CodecType: "video", + Width: stream.Width, + Height: stream.Height, + Bitrate: stream.Bitrate, + } + }) + streams = append(streams, lo.Map(info.MediaInfo.Stream.Audio, func(stream Audio, index int) mediameta.Stream { + return mediameta.Stream{ + Index: stream.Index, + CodecName: stream.CodecName, + CodecLongName: stream.CodecLongName, + CodecType: "audio", + Bitrate: stream.Bitrate, + } + })...) + + metas := make([]driver.MediaMeta, 0) + metas = append(metas, mediameta.ProbeMetaTransform(&mediameta.FFProbeMeta{ + Format: &mediameta.Format{ + FormatName: info.MediaInfo.Format.FormatName, + FormatLongName: info.MediaInfo.Format.FormatLongName, + Duration: info.MediaInfo.Format.Duration, + Bitrate: info.MediaInfo.Format.Bitrate, + }, + Streams: streams, + })...) + + return nil, nil +} + +func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + exif := "" + resp, err := handler.extractMediaInfo(ctx, path, &urlOption{ + Exif: &exif, + }) + if err != nil { + return nil, err + } + + var imageInfo ImageInfo + if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + metas := make([]driver.MediaMeta, 0) + exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) { + return key, value.Value + }) + metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...) + metas = append(metas, parseGpsInfo(imageInfo)...) + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + + return metas, nil +} + +// extractMediaInfo Sends API calls to COS service to extract media info. 
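+// It signs a short-lived GET URL for the object using the given urlOption (the
+// videoinfo ci-process for streams, the exif query for images) and returns the
+// raw response body; callers decode the XML or JSON payload themselves.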
+func (handler *Driver) extractMediaInfo(ctx context.Context, path string, opt *urlOption) (string, error) { + mediaInfoExpire := time.Now().Add(mediaInfoTTL) + thumbURL, err := handler.signSourceURL( + ctx, + path, + &mediaInfoExpire, + opt, + ) + if err != nil { + return "", fmt.Errorf("failed to sign media info url: %w", err) + } + + resp, err := handler.httpClient. + Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)). + CheckHTTPResponse(http.StatusOK). + GetResponseIgnoreErr() + if err != nil { + return "", handleCosError(resp, err) + } + + return resp, nil +} + +func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta { + latitude := imageInfo["GPSLatitude"] // 31deg 16.26808' + longitude := imageInfo["GPSLongitude"] // 120deg 42.91039' + latRef := imageInfo["GPSLatitudeRef"] // North + lonRef := imageInfo["GPSLongitudeRef"] // East + + // Make sure all value exist in map + if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" { + return nil + } + + lat := parseRawGPS(latitude.Value, latRef.Value) + lon := parseRawGPS(longitude.Value, lonRef.Value) + if !math.IsNaN(lat) && !math.IsNaN(lon) { + lat, lng := mediameta.NormalizeGPS(lat, lon) + return []driver.MediaMeta{{ + Key: mediameta.GpsLat, + Value: fmt.Sprintf("%f", lat), + }, { + Key: mediameta.GpsLng, + Value: fmt.Sprintf("%f", lng), + }} + } + + return nil +} + +func parseRawGPS(gpsStr string, ref string) float64 { + elem := strings.Split(gpsStr, " ") + if len(elem) < 1 { + return 0 + } + + var ( + deg float64 + minutes float64 + seconds float64 + ) + + deg = getGpsElemValue(elem[0]) + if len(elem) >= 2 { + minutes = getGpsElemValue(elem[1]) + } + if len(elem) >= 3 { + seconds = getGpsElemValue(elem[2]) + } + + decimal := deg + minutes/60.0 + seconds/3600.0 + + if ref == "S" || ref == "W" { + return -decimal + } + + return decimal +} + +func getGpsElemValue(elm string) float64 { + elements := strings.Split(elm, "/") + if len(elements) != 2 { + return 0 + } + + numerator, err := strconv.ParseFloat(elements[0], 64) + if err != nil { + return 0 + } + + denominator, err := strconv.ParseFloat(elements[1], 64) + if err != nil || denominator == 0 { + return 0 + } + + return numerator / denominator +} + +func handleCosError(resp string, originErr error) error { + if resp == "" { + return originErr + } + + var err Error + if err := xml.Unmarshal([]byte(resp), &err); err != nil { + return fmt.Errorf("failed to unmarshal cos error: %w", err) + } + + return fmt.Errorf("cos error: %s", err.Message) +} diff --git a/pkg/filemanager/driver/cos/scf.go b/pkg/filemanager/driver/cos/scf.go new file mode 100644 index 00000000..494893f6 --- /dev/null +++ b/pkg/filemanager/driver/cos/scf.go @@ -0,0 +1,118 @@ +package cos + +// TODO: revisit para error +const scfFunc = `# -*- coding: utf8 -*- +# SCF配置COS触发,向 Cloudreve 发送回调 +from qcloud_cos_v5 import CosConfig +from qcloud_cos_v5 import CosS3Client +from qcloud_cos_v5 import CosServiceError +from qcloud_cos_v5 import CosClientError +import sys +import logging +import requests + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +logging = logging.getLogger() + + +def main_handler(event, context): + logging.info("start main handler") + for record in event['Records']: + try: + if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']: + logging.info("Cannot find callback URL, skiped.") + return 'Success' + callback = record['cos']['cosObject']['meta']['x-cos-meta-callback'] + key = record['cos']['cosObject']['key'] + 
logging.info("Callback URL is " + callback) + + r = requests.get(callback) + print(r.text) + + + + except Exception as e: + print(e) + print('Error getting object {} callback url. '.format(key)) + raise e + return "Fail" + + return "Success" +` + +// +//// CreateSCF 创建回调云函数 +//func CreateSCF(policy *model.Policy, region string) error { +// // 初始化客户端 +// credential := common.NewCredential( +// policy.AccessKey, +// policy.SecretKey, +// ) +// cpf := profile.NewClientProfile() +// client, err := scf.NewClient(credential, region, cpf) +// if err != nil { +// return err +// } +// +// // 创建回调代码数据 +// buff := &bytes.Buffer{} +// bs64 := base64.NewEncoder(base64.StdEncoding, buff) +// zipWriter := zip.NewWriter(bs64) +// header := zip.FileHeader{ +// Name: "callback.py", +// Method: zip.Deflate, +// } +// writer, err := zipWriter.CreateHeader(&header) +// if err != nil { +// return err +// } +// _, err = io.Copy(writer, strings.NewReader(scfFunc)) +// zipWriter.Close() +// +// // 创建云函数 +// req := scf.NewCreateFunctionRequest() +// funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10) +// zipFileBytes, _ := ioutil.ReadAll(buff) +// zipFileStr := string(zipFileBytes) +// codeSource := "ZipFile" +// handler := "callback.main_handler" +// desc := "Cloudreve 用回调函数" +// timeout := int64(60) +// runtime := "Python3.6" +// req.FunctionName = &funcName +// req.Code = &scf.Code{ +// ZipFile: &zipFileStr, +// } +// req.Handler = &handler +// req.Description = &desc +// req.Timeout = &timeout +// req.Runtime = &runtime +// req.CodeSource = &codeSource +// +// _, err = client.CreateFunction(req) +// if err != nil { +// return err +// } +// +// time.Sleep(time.Duration(5) * time.Second) +// +// // 创建触发器 +// server, _ := url.Parse(policy.Server) +// triggerType := "cos" +// triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}` +// enable := "OPEN" +// +// trigger := scf.NewCreateTriggerRequest() +// trigger.FunctionName = &funcName +// trigger.TriggerName = &server.Host +// trigger.Type = &triggerType +// trigger.TriggerDesc = &triggerDesc +// trigger.Enable = &enable +// +// _, err = client.CreateTrigger(trigger) +// if err != nil { +// return err +// } +// +// return nil +//} diff --git a/pkg/filemanager/driver/handler.go b/pkg/filemanager/driver/handler.go new file mode 100644 index 00000000..e6bbddc1 --- /dev/null +++ b/pkg/filemanager/driver/handler.go @@ -0,0 +1,122 @@ +package driver + +import ( + "context" + "os" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" +) + +const ( + // HandlerCapabilityProxyRequired this handler requires Cloudreve's proxy to get file content + HandlerCapabilityProxyRequired HandlerCapability = iota + // HandlerCapabilityInboundGet this handler supports directly get file's RSCloser, usually + // indicates that the file is stored in the same machine as Cloudreve + HandlerCapabilityInboundGet + // HandlerCapabilityUploadSentinelRequired this handler does not support compliance callback mechanism, + // thus it requires Cloudreve's sentinel to guarantee the upload is under control. Cloudreve will try + // to delete the placeholder file and cancel the upload session if upload callback is not made after upload + // session expire. 
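+ // Handlers advertise this capability through the StaticFeatures boolset returned by Capabilities().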
+ HandlerCapabilityUploadSentinelRequired +) + +type ( + MetaType string + MediaMeta struct { + Key string `json:"key"` + Value string `json:"value"` + Type MetaType `json:"type"` + } + + HandlerCapability int + + GetSourceArgs struct { + Expire *time.Time + IsDownload bool + Speed int64 + DisplayName string + } + + // Handler 存储策略适配器 + Handler interface { + // 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭 + // 时,应取消上传并清理临时文件 + Put(ctx context.Context, file *fs.UploadRequest) error + + // 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误 + Delete(ctx context.Context, files ...string) ([]string, error) + + // Open physical files. Only implemented if HandlerCapabilityInboundGet capability is set. + // Returns file path and an os.File object. + Open(ctx context.Context, path string) (*os.File, error) + + // LocalPath returns the local path of a file. + // Only implemented if HandlerCapabilityInboundGet capability is set. + LocalPath(ctx context.Context, path string) string + + // Thumb returns the URL for a thumbnail of given entity. + Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) + + // 获取外链/下载地址, + // url - 站点本身地址, + // isDownload - 是否直接下载 + Source(ctx context.Context, e fs.Entity, args *GetSourceArgs) (string, error) + + // Token 获取有效期为ttl的上传凭证和签名 + Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) + + // CancelToken 取消已经创建的有状态上传凭证 + CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error + + // CompleteUpload completes a previously created upload session. + CompleteUpload(ctx context.Context, session *fs.UploadSession) error + + // List 递归列取远程端path路径下文件、目录,不包含path本身, + // 返回的对象路径以path作为起始根目录. + // recursive - 是否递归列出 + // List(ctx context.Context, path string, recursive bool) ([]response.Object, error) + + // Capabilities returns the capabilities of this handler + Capabilities() *Capabilities + + // MediaMeta extracts media metadata from the given file. + MediaMeta(ctx context.Context, path, ext string) ([]MediaMeta, error) + } + + Capabilities struct { + StaticFeatures *boolset.BooleanSet + // MaxSourceExpire indicates the maximum allowed expiration duration of a source URL + MaxSourceExpire time.Duration + // MinSourceExpire indicates the minimum allowed expiration duration of a source URL + MinSourceExpire time.Duration + // MediaMetaSupportedExts indicates the extensions of files that support media metadata. Empty list + // indicates that no file supports extracting media metadata. + MediaMetaSupportedExts []string + // GenerateMediaMeta indicates whether to generate media metadata using local generators. + MediaMetaProxy bool + // ThumbSupportedExts indicates the extensions of files that support thumbnail generation. Empty list + // indicates that no file supports thumbnail generation. + ThumbSupportedExts []string + // ThumbSupportAllExts indicates whether to generate thumbnails for all files, regardless of their extensions. + ThumbSupportAllExts bool + // ThumbMaxSize indicates the maximum allowed size of a thumbnail. 0 indicates that no limit is set. + ThumbMaxSize int64 + // ThumbProxy indicates whether to generate thumbnails using local generators. + ThumbProxy bool + } +) + +const ( + MetaTypeExif MetaType = "exif" + MediaTypeMusic MetaType = "music" + MetaTypeStreamMedia MetaType = "stream" +) + +type ForceUsePublicEndpointCtx struct{} + +// WithForcePublicEndpoint sets the context to force using public endpoint for supported storage policies. 
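+// A typical call site (hypothetical) wraps the request context before asking a
+// handler for a URL:
+//
+//	ctx = driver.WithForcePublicEndpoint(ctx, true)
+//	url, err := handler.Source(ctx, entity, args)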
+func WithForcePublicEndpoint(ctx context.Context, value bool) context.Context { + return context.WithValue(ctx, ForceUsePublicEndpointCtx{}, value) +} diff --git a/pkg/filemanager/driver/local/entity.go b/pkg/filemanager/driver/local/entity.go new file mode 100644 index 00000000..d725ed45 --- /dev/null +++ b/pkg/filemanager/driver/local/entity.go @@ -0,0 +1,75 @@ +package local + +import ( + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "os" + "time" +) + +// NewLocalFileEntity creates a new local file entity. +func NewLocalFileEntity(t types.EntityType, src string) (fs.Entity, error) { + info, err := os.Stat(util.RelativePath(src)) + if err != nil { + return nil, err + } + + return &localFileEntity{ + t: t, + src: src, + size: info.Size(), + }, nil +} + +type localFileEntity struct { + t types.EntityType + src string + size int64 +} + +func (l *localFileEntity) ID() int { + return 0 +} + +func (l *localFileEntity) Type() types.EntityType { + return l.t +} + +func (l *localFileEntity) Size() int64 { + return l.size +} + +func (l *localFileEntity) UpdatedAt() time.Time { + return time.Now() +} + +func (l *localFileEntity) CreatedAt() time.Time { + return time.Now() +} + +func (l *localFileEntity) CreatedBy() *ent.User { + return nil +} + +func (l *localFileEntity) Source() string { + return l.src +} + +func (l *localFileEntity) ReferenceCount() int { + return 1 +} + +func (l *localFileEntity) PolicyID() int { + return 0 +} + +func (l *localFileEntity) UploadSessionID() *uuid.UUID { + return nil +} + +func (l *localFileEntity) Model() *ent.Entity { + return nil +} diff --git a/pkg/filemanager/driver/local/fallocate.go b/pkg/filemanager/driver/local/fallocate.go new file mode 100644 index 00000000..a3ff97c6 --- /dev/null +++ b/pkg/filemanager/driver/local/fallocate.go @@ -0,0 +1,11 @@ +//go:build !linux && !darwin +// +build !linux,!darwin + +package local + +import "os" + +// No-op on non-Linux/Darwin platforms. 
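+// Preallocation is best-effort: on platforms without a native fallocate
+// equivalent the placeholder file simply grows as data is written, and callers
+// only log a warning when preallocation is unavailable or fails.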
+func Fallocate(file *os.File, offset int64, length int64) error { + return nil +} diff --git a/pkg/filemanager/driver/local/fallocate_darwin.go b/pkg/filemanager/driver/local/fallocate_darwin.go new file mode 100644 index 00000000..05ba321d --- /dev/null +++ b/pkg/filemanager/driver/local/fallocate_darwin.go @@ -0,0 +1,27 @@ +package local + +import ( + "os" + "syscall" + "unsafe" +) + +func Fallocate(file *os.File, offset int64, length int64) error { + var fst syscall.Fstore_t + + fst.Flags = syscall.F_ALLOCATECONTIG + fst.Posmode = syscall.F_PREALLOCATE + fst.Offset = 0 + fst.Length = offset + length + fst.Bytesalloc = 0 + + // Check https://lists.apple.com/archives/darwin-dev/2007/Dec/msg00040.html + _, _, err := syscall.Syscall(syscall.SYS_FCNTL, file.Fd(), syscall.F_PREALLOCATE, uintptr(unsafe.Pointer(&fst))) + if err != syscall.Errno(0x0) { + fst.Flags = syscall.F_ALLOCATEALL + // Ignore the return value + _, _, _ = syscall.Syscall(syscall.SYS_FCNTL, file.Fd(), syscall.F_PREALLOCATE, uintptr(unsafe.Pointer(&fst))) + } + + return syscall.Ftruncate(int(file.Fd()), fst.Length) +} diff --git a/pkg/filemanager/driver/local/fallocate_linux.go b/pkg/filemanager/driver/local/fallocate_linux.go new file mode 100644 index 00000000..9c247dfc --- /dev/null +++ b/pkg/filemanager/driver/local/fallocate_linux.go @@ -0,0 +1,14 @@ +package local + +import ( + "os" + "syscall" +) + +func Fallocate(file *os.File, offset int64, length int64) error { + if length == 0 { + return nil + } + + return syscall.Fallocate(int(file.Fd()), 0, offset, length) +} diff --git a/pkg/filemanager/driver/local/local.go b/pkg/filemanager/driver/local/local.go new file mode 100644 index 00000000..88949b63 --- /dev/null +++ b/pkg/filemanager/driver/local/local.go @@ -0,0 +1,301 @@ +package local + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" +) + +const ( + Perm = 0744 +) + +var ( + capabilities = &driver.Capabilities{ + StaticFeatures: &boolset.BooleanSet{}, + MediaMetaProxy: true, + ThumbProxy: true, + } +) + +func init() { + boolset.Sets(map[driver.HandlerCapability]bool{ + driver.HandlerCapabilityProxyRequired: true, + driver.HandlerCapabilityInboundGet: true, + }, capabilities.StaticFeatures) +} + +// Driver 本地策略适配器 +type Driver struct { + Policy *ent.StoragePolicy + httpClient request.Client + l logging.Logger + config conf.ConfigProvider +} + +// New constructs a new local driver +func New(p *ent.StoragePolicy, l logging.Logger, config conf.ConfigProvider) *Driver { + return &Driver{ + Policy: p, + l: l, + httpClient: request.NewClient(config, request.WithLogger(l)), + config: config, + } +} + +//// List 递归列取给定物理路径下所有文件 +//func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { +// var res []response.Object +// +// // 取得起始路径 +// root := util.RelativePath(filepath.FromSlash(path)) +// +// // 开始遍历路径下的文件、目录 +// err := filepath.Walk(root, +// func(path string, info os.FileInfo, err error) error { +// // 跳过根目录 +// if path == root { +// 
return nil +// } +// +// if err != nil { +// util.Log().Warning("Failed to walk folder %q: %s", path, err) +// return filepath.SkipDir +// } +// +// // 将遍历对象的绝对路径转换为相对路径 +// rel, err := filepath.Rel(root, path) +// if err != nil { +// return err +// } +// +// res = append(res, response.Object{ +// Name: info.Name(), +// RelativePath: filepath.ToSlash(rel), +// Source: path, +// Size: uint64(info.Size()), +// IsDir: info.IsDir(), +// LastModify: info.ModTime(), +// }) +// +// // 如果非递归,则不步入目录 +// if !recursive && info.IsDir() { +// return filepath.SkipDir +// } +// +// return nil +// }) +// +// return res, err +//} + +// Get 获取文件内容 +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + // 打开文件 + file, err := os.Open(handler.LocalPath(ctx, path)) + if err != nil { + handler.l.Debug("Failed to open file: %s", err) + return nil, err + } + + return file, nil +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return util.RelativePath(filepath.FromSlash(path)) +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + dst := util.RelativePath(filepath.FromSlash(file.Props.SavePath)) + + // 如果非 Overwrite,则检查是否有重名冲突 + if file.Mode&fs.ModeOverwrite != fs.ModeOverwrite { + if util.Exists(dst) { + handler.l.Warning("File with the same name existed or unavailable: %s", dst) + return errors.New("file with the same name existed or unavailable") + } + } + + if err := handler.prepareFileDirectory(dst); err != nil { + return err + } + + openMode := os.O_CREATE | os.O_RDWR + if file.Mode&fs.ModeOverwrite == fs.ModeOverwrite && file.Offset == 0 { + openMode |= os.O_TRUNC + } + + out, err := os.OpenFile(dst, openMode, Perm) + if err != nil { + handler.l.Warning("Failed to open or create file: %s", err) + return err + } + defer out.Close() + + stat, err := out.Stat() + if err != nil { + handler.l.Warning("Failed to read file info: %s", err) + return err + } + + if stat.Size() < file.Offset { + return errors.New("size of unfinished uploaded chunks is not as expected") + } + + if _, err := out.Seek(file.Offset, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek to desired offset %d: %s", file.Offset, err) + } + + // 写入文件内容 + _, err = io.Copy(out, file) + return err +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + deleteFailed := make([]string, 0, len(files)) + var retErr error + + for _, value := range files { + filePath := util.RelativePath(filepath.FromSlash(value)) + if util.Exists(filePath) { + err := os.Remove(filePath) + if err != nil { + handler.l.Warning("Failed to delete file: %s", err) + retErr = err + deleteFailed = append(deleteFailed, value) + } + } + + //// 尝试删除文件的缩略图(如果有) + //_ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"))) + } + + return deleteFailed, retErr +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + return "", errors.New("not implemented") +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + return "", errors.New("not implemented") +} + +// Token 获取上传策略和认证Token,本地策略直接返回空值 +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + if 
file.Mode&fs.ModeOverwrite != fs.ModeOverwrite && util.Exists(uploadSession.Props.SavePath) { + return nil, errors.New("placeholder file already exist") + } + + dst := util.RelativePath(filepath.FromSlash(uploadSession.Props.SavePath)) + if err := handler.prepareFileDirectory(dst); err != nil { + return nil, fmt.Errorf("failed to prepare file directory: %w", err) + } + + f, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, Perm) + if err != nil { + return nil, fmt.Errorf("failed to create placeholder file: %w", err) + } + + // Preallocate disk space + defer f.Close() + if handler.Policy.Settings.PreAllocate { + if err := Fallocate(f, 0, uploadSession.Props.Size); err != nil { + handler.l.Warning("Failed to preallocate file: %s", err) + } + } + + return &fs.UploadCredential{ + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.Policy.Settings.ChunkSize, + }, nil +} + +func (h *Driver) prepareFileDirectory(dst string) error { + basePath := filepath.Dir(dst) + if !util.Exists(basePath) { + err := os.MkdirAll(basePath, Perm) + if err != nil { + h.l.Warning("Failed to create directory: %s", err) + return err + } + } + + return nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + return nil +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + if session.Callback == "" { + return nil + } + + if session.Policy.Edges.Node == nil { + return serializer.NewError(serializer.CodeCallbackError, "Node not found", nil) + } + + // If callback is set, indicating this handler is used in slave node as a shadowed handler for remote policy, + // we need to send callback request to master node. + resp := handler.httpClient.Request( + "POST", + session.Callback, + nil, + request.WithTimeout(time.Duration(handler.config.Slave().CallbackTimeout)*time.Second), + request.WithCredential( + auth.HMACAuth{[]byte(session.Policy.Edges.Node.SlaveKey)}, + int64(handler.config.Slave().SignatureTTL), + ), + request.WithContext(ctx), + request.WithCorrelationID(), + ) + + if resp.Err != nil { + return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err) + } + + // 解析回调服务端响应 + res, err := resp.DecodeResponse() + if err != nil { + msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode) + return serializer.NewError(serializer.CodeCallbackError, msg, err) + } + + if res.Code != 0 { + return serializer.NewError(res.Code, res.Msg, errors.New(res.Error)) + } + + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + return capabilities +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + return nil, errors.New("not implemented") +} diff --git a/pkg/filemanager/driver/obs/media.go b/pkg/filemanager/driver/obs/media.go new file mode 100644 index 00000000..b7d75e41 --- /dev/null +++ b/pkg/filemanager/driver/obs/media.go @@ -0,0 +1,137 @@ +package obs + +import ( + "context" + "encoding/json" + "fmt" + "math" + "net/http" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/huaweicloud/huaweicloud-sdk-go-obs/obs" + "github.com/samber/lo" +) + +func (d *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + thumbURL, err := 
d.signSourceURL(&obs.CreateSignedUrlInput{ + Method: obs.HttpMethodGet, + Bucket: d.policy.BucketName, + Key: path, + Expires: int(mediaInfoTTL.Seconds()), + QueryParams: map[string]string{ + imageProcessHeader: imageInfoProcessor, + }, + }) + + if err != nil { + return nil, fmt.Errorf("failed to sign media info url: %w", err) + } + + resp, err := d.httpClient. + Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)). + CheckHTTPResponse(http.StatusOK). + GetResponseIgnoreErr() + if err != nil { + return nil, handleJsonError(resp, err) + } + + var imageInfo map[string]any + if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + imageInfoMap := lo.MapEntries(imageInfo, func(k string, v any) (string, string) { + if vStr, ok := v.(string); ok { + return strings.TrimPrefix(k, "exif:"), vStr + } + + return k, fmt.Sprintf("%v", v) + }) + metas := make([]driver.MediaMeta, 0) + metas = append(metas, mediameta.ExtractExifMap(imageInfoMap, time.Time{})...) + metas = append(metas, parseGpsInfo(imageInfoMap)...) + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + return metas, nil +} + +func parseGpsInfo(imageInfo map[string]string) []driver.MediaMeta { + latitude := imageInfo["GPSLatitude"] // 31/1, 162680820/10000000, 0/1 + longitude := imageInfo["GPSLongitude"] // 120/1, 429103939/10000000, 0/1 + latRef := imageInfo["GPSLatitudeRef"] // N + lonRef := imageInfo["GPSLongitudeRef"] // E + + // Make sure all value exist in map + if latitude == "" || longitude == "" || latRef == "" || lonRef == "" { + return nil + } + + lat := parseRawGPS(latitude, latRef) + lon := parseRawGPS(longitude, lonRef) + if !math.IsNaN(lat) && !math.IsNaN(lon) { + lat, lng := mediameta.NormalizeGPS(lat, lon) + return []driver.MediaMeta{{ + Key: mediameta.GpsLat, + Value: fmt.Sprintf("%f", lat), + }, { + Key: mediameta.GpsLng, + Value: fmt.Sprintf("%f", lng), + }} + } + + return nil +} + +func parseRawGPS(gpsStr string, ref string) float64 { + elem := strings.Split(gpsStr, ", ") + if len(elem) < 1 { + return 0 + } + + var ( + deg float64 + minutes float64 + seconds float64 + ) + + deg = getGpsElemValue(elem[0]) + if len(elem) >= 2 { + minutes = getGpsElemValue(elem[1]) + } + if len(elem) >= 3 { + seconds = getGpsElemValue(elem[2]) + } + + decimal := deg + minutes/60.0 + seconds/3600.0 + + if ref == "S" || ref == "W" { + return -decimal + } + + return decimal +} + +func getGpsElemValue(elm string) float64 { + elements := strings.Split(elm, "/") + if len(elements) != 2 { + return 0 + } + + numerator, err := strconv.ParseFloat(elements[0], 64) + if err != nil { + return 0 + } + + denominator, err := strconv.ParseFloat(elements[1], 64) + if err != nil || denominator == 0 { + return 0 + } + + return numerator / denominator +} diff --git a/pkg/filemanager/driver/obs/obs.go b/pkg/filemanager/driver/obs/obs.go new file mode 100644 index 00000000..93313f5b --- /dev/null +++ b/pkg/filemanager/driver/obs/obs.go @@ -0,0 +1,513 @@ +package obs + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + 
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/huaweicloud/huaweicloud-sdk-go-obs/obs" + "github.com/samber/lo" +) + +const ( + chunkRetrySleep = time.Duration(5) * time.Second + maxDeleteBatch = 1000 + imageProcessHeader = "x-image-process" + trafficLimitHeader = "x-obs-traffic-limit" + partNumberParam = "partNumber" + callbackParam = "x-obs-callback" + uploadIdParam = "uploadId" + mediaInfoTTL = time.Duration(10) * time.Minute + imageInfoProcessor = "image/info" + + // MultiPartUploadThreshold 服务端使用分片上传的阈值 + MultiPartUploadThreshold int64 = 5 << 30 // 5GB +) + +var ( + features = &boolset.BooleanSet{} +) + +type ( + CallbackPolicy struct { + CallbackURL string `json:"callbackUrl"` + CallbackBody string `json:"callbackBody"` + CallbackBodyType string `json:"callbackBodyType"` + } + JsonError struct { + Message string `json:"message"` + Code string `json:"code"` + } +) + +// Driver Huawei Cloud OBS driver +type Driver struct { + policy *ent.StoragePolicy + chunkSize int64 + + settings setting.Provider + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + httpClient request.Client + obs *obs.ObsClient +} + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 25 << 20 // 25 MB + } + + driver := &Driver{ + policy: policy, + settings: settings, + chunkSize: chunkSize, + config: config, + l: l, + mime: mime, + httpClient: request.NewClient(config, request.WithLogger(l)), + } + + useCname := false + if policy.Settings != nil && policy.Settings.UseCname { + useCname = true + } + + obsClient, err := obs.New(policy.AccessKey, policy.SecretKey, policy.Server, obs.WithSignature(obs.SignatureObs), obs.WithCustomDomainName(useCname)) + if err != nil { + return nil, err + } + + driver.obs = obsClient + return driver, nil +} + +func (d *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + if !overwrite { + // Check for duplicated file + if _, err := d.obs.HeadObject(&obs.HeadObjectInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + }, obs.WithRequestContext(ctx)); err == nil { + return fs.ErrFileExisted + } + } + + mimeType := file.Props.MimeType + if mimeType == "" { + d.mime.TypeByName(file.Props.Uri.Name()) + } + + // 小文件直接上传 + if file.Props.Size < MultiPartUploadThreshold { + _, err := d.obs.PutObject(&obs.PutObjectInput{ + PutObjectBasicInput: obs.PutObjectBasicInput{ + ObjectOperationInput: obs.ObjectOperationInput{ + Key: file.Props.SavePath, + Bucket: d.policy.BucketName, + }, + HttpHeader: obs.HttpHeader{ + ContentType: mimeType, + }, + ContentLength: file.Props.Size, + }, + Body: file, + }, obs.WithRequestContext(ctx)) + return err + } + + // 超过阈值时使用分片上传 + imur, err := d.obs.InitiateMultipartUpload(&obs.InitiateMultipartUploadInput{ + ObjectOperationInput: obs.ObjectOperationInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + }, + HttpHeader: obs.HttpHeader{ + 
ContentType: d.mime.TypeByName(file.Props.Uri.Name()), + }, + }, obs.WithRequestContext(ctx)) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + chunks := chunk.NewChunkGroup(file, d.chunkSize, &backoff.ConstantBackoff{ + Max: d.settings.ChunkRetryLimit(ctx), + Sleep: chunkRetrySleep, + }, d.settings.UseChunkBuffer(ctx), d.l, d.settings.TempPath(ctx)) + + parts := make([]*obs.UploadPartOutput, 0, chunks.Num()) + + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + part, err := d.obs.UploadPart(&obs.UploadPartInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + PartNumber: current.Index() + 1, + UploadId: imur.UploadId, + Body: content, + SourceFile: "", + PartSize: current.Length(), + }, obs.WithRequestContext(ctx)) + if err == nil { + parts = append(parts, part) + } + return err + } + + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + d.cancelUpload(file.Props.SavePath, imur) + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + _, err = d.obs.CompleteMultipartUpload(&obs.CompleteMultipartUploadInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + UploadId: imur.UploadId, + Parts: lo.Map(parts, func(part *obs.UploadPartOutput, i int) obs.Part { + return obs.Part{ + PartNumber: i + 1, + ETag: part.ETag, + } + }), + }, obs.WithRequestContext(ctx)) + if err != nil { + d.cancelUpload(file.Props.SavePath, imur) + } + + return err +} + +func (d *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + groups := lo.Chunk(files, maxDeleteBatch) + failed := make([]string, 0) + var lastError error + for index, group := range groups { + d.l.Debug("Process delete group #%d: %v", index, group) + // 删除文件 + delRes, err := d.obs.DeleteObjects(&obs.DeleteObjectsInput{ + Bucket: d.policy.BucketName, + Quiet: true, + Objects: lo.Map(group, func(item string, index int) obs.ObjectToDelete { + return obs.ObjectToDelete{ + Key: item, + } + }), + }, obs.WithRequestContext(ctx)) + if err != nil { + failed = append(failed, group...) 
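+ // Whole-batch failure: every key in this group is recorded as failed; per-key
+ // errors from a successful call are collected from delRes.Errors below.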
+ lastError = err + continue + } + + for _, v := range delRes.Errors { + d.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key) + failed = append(failed, v.Key) + } + } + + if len(failed) > 0 && lastError == nil { + lastError = fmt.Errorf("failed to delete files: %v", failed) + } + + return failed, lastError +} + +func (d *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +func (d *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func (d *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + w, h := d.settings.ThumbSize(ctx) + thumbURL, err := d.signSourceURL(&obs.CreateSignedUrlInput{ + Method: obs.HttpMethodGet, + Bucket: d.policy.BucketName, + Key: e.Source(), + Expires: int(time.Until(*expire).Seconds()), + QueryParams: map[string]string{ + imageProcessHeader: fmt.Sprintf("image/resize,m_lfit,w_%d,h_%d", w, h), + }, + }) + + if err != nil { + return "", err + } + + return thumbURL, nil +} + +func (d *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + params := make(map[string]string) + if args.IsDownload { + encodedFilename := url.PathEscape(args.DisplayName) + params["response-content-disposition"] = fmt.Sprintf("attachment; filename=\"%s\"; filename*=UTF-8''%s", + args.DisplayName, encodedFilename) + } + + expires := 86400 * 265 * 20 + if args.Expire != nil { + expires = int(time.Until(*args.Expire).Seconds()) + } + + if args.Speed > 0 { + // Byte 转换为 bit + args.Speed *= 8 + + // OSS对速度值有范围限制 + if args.Speed < 819200 { + args.Speed = 819200 + } + if args.Speed > 838860800 { + args.Speed = 838860800 + } + } + + if args.Speed > 0 { + params[trafficLimitHeader] = strconv.FormatInt(args.Speed, 10) + } + + return d.signSourceURL(&obs.CreateSignedUrlInput{ + Method: obs.HttpMethodGet, + Bucket: d.policy.BucketName, + Key: e.Source(), + Expires: expires, + QueryParams: params, + }) +} + +func (d *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // Check for duplicated file + if _, err := d.obs.HeadObject(&obs.HeadObjectInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + }, obs.WithRequestContext(ctx)); err == nil { + return nil, fs.ErrFileExisted + } + + // 生成回调地址 + siteURL := d.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + // 在从机端创建上传会话 + uploadSession.ChunkSize = d.chunkSize + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeObs, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + // 回调策略 + callbackPolicy := CallbackPolicy{ + CallbackURL: uploadSession.Callback, + CallbackBody: `{"name":${key},"source_name":${fname},"size":${size}}`, + CallbackBodyType: "application/json", + } + + callbackPolicyJSON, err := json.Marshal(callbackPolicy) + if err != nil { + return nil, fmt.Errorf("failed to encode callback policy: %w", err) + } + callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON) + + mimeType := file.Props.MimeType + if mimeType == "" { + d.mime.TypeByName(file.Props.Uri.Name()) + } + + imur, err := d.obs.InitiateMultipartUpload(&obs.InitiateMultipartUploadInput{ + ObjectOperationInput: obs.ObjectOperationInput{ + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + }, + HttpHeader: obs.HttpHeader{ + ContentType: mimeType, + }, + }, obs.WithRequestContext(ctx)) + if err != nil { + return nil, 
fmt.Errorf("failed to initialize multipart upload: %w", err) + } + uploadSession.UploadID = imur.UploadId + + // 为每个分片签名上传 URL + chunks := chunk.NewChunkGroup(file, d.chunkSize, &backoff.ConstantBackoff{}, false, d.l, "") + urls := make([]string, chunks.Num()) + ttl := int64(time.Until(uploadSession.Props.ExpireAt).Seconds()) + for chunks.Next() { + err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { + signedURL, err := d.obs.CreateSignedUrl(&obs.CreateSignedUrlInput{ + Method: obs.HttpMethodPut, + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + QueryParams: map[string]string{ + partNumberParam: strconv.Itoa(c.Index() + 1), + uploadIdParam: uploadSession.UploadID, + }, + Expires: int(ttl), + Headers: map[string]string{ + "Content-Length": strconv.FormatInt(c.Length(), 10), + "Content-Type": "application/octet-stream", + }, //TODO: Validate +1 + }) + if err != nil { + return err + } + + urls[c.Index()] = signedURL.SignedUrl + return nil + }) + if err != nil { + return nil, err + } + } + + // 签名完成分片上传的URL + completeURL, err := d.obs.CreateSignedUrl(&obs.CreateSignedUrlInput{ + Method: obs.HttpMethodPost, + Bucket: d.policy.BucketName, + Key: file.Props.SavePath, + QueryParams: map[string]string{ + uploadIdParam: uploadSession.UploadID, + callbackParam: callbackPolicyEncoded, + }, + Headers: map[string]string{ + "Content-Type": "application/octet-stream", + }, + Expires: int(ttl), + }) + if err != nil { + return nil, err + } + + return &fs.UploadCredential{ + UploadID: imur.UploadId, + UploadURLs: urls, + CompleteURL: completeURL.SignedUrl, + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: d.chunkSize, + }, nil +} + +func (d *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + _, err := d.obs.AbortMultipartUpload(&obs.AbortMultipartUploadInput{ + Bucket: d.policy.BucketName, + Key: uploadSession.Props.SavePath, + UploadId: uploadSession.UploadID, + }, obs.WithRequestContext(ctx)) + return err +} + +func (d *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + return nil +} + +//func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { +// return nil, errors.New("not implemented") +//} + +func (d *Driver) Capabilities() *driver.Capabilities { + mediaMetaExts := d.policy.Settings.MediaMetaExts + if !d.policy.Settings.NativeMediaProcessing { + mediaMetaExts = nil + } + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: mediaMetaExts, + MediaMetaProxy: d.policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: d.policy.Settings.ThumbExts, + ThumbProxy: d.policy.Settings.ThumbGeneratorProxy, + ThumbSupportAllExts: d.policy.Settings.ThumbSupportAllExts, + ThumbMaxSize: d.policy.Settings.ThumbMaxSize, + } +} + +// CORS 创建跨域策略 +func (d *Driver) CORS() error { + _, err := d.obs.SetBucketCors(&obs.SetBucketCorsInput{ + Bucket: d.policy.BucketName, + BucketCors: obs.BucketCors{ + CorsRules: []obs.CorsRule{ + { + AllowedOrigin: []string{"*"}, + AllowedMethod: []string{ + "GET", + "POST", + "PUT", + "DELETE", + "HEAD", + }, + ExposeHeader: []string{"Etag"}, + AllowedHeader: []string{"*"}, + MaxAgeSeconds: 3600, + }, + }, + }, + }) + return err +} + +func (d *Driver) cancelUpload(path string, imur *obs.InitiateMultipartUploadOutput) { + if _, err := d.obs.AbortMultipartUpload(&obs.AbortMultipartUploadInput{ + Bucket: d.policy.BucketName, + Key: path, + UploadId: imur.UploadId, + }); err != nil { + d.l.Warning("failed to 
abort multipart upload: %s", err) + } +} + +func (handler *Driver) signSourceURL(input *obs.CreateSignedUrlInput) (string, error) { + signedURL, err := handler.obs.CreateSignedUrl(input) + if err != nil { + return "", err + } + + finalURL, err := url.Parse(signedURL.SignedUrl) + if err != nil { + return "", err + } + + // 公有空间替换掉Key及不支持的头 + if !handler.policy.IsPrivate { + query := finalURL.Query() + query.Del("AccessKeyId") + query.Del("Signature") + finalURL.RawQuery = query.Encode() + } + return finalURL.String(), nil +} + +func handleJsonError(resp string, originErr error) error { + if resp == "" { + return originErr + } + + var err JsonError + if err := json.Unmarshal([]byte(resp), &err); err != nil { + return fmt.Errorf("failed to unmarshal cos error: %w", err) + } + + return fmt.Errorf("obs error: %s", err.Message) +} diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filemanager/driver/onedrive/api.go similarity index 65% rename from pkg/filesystem/driver/onedrive/api.go rename to pkg/filemanager/driver/onedrive/api.go index 56abbaa9..4b73b055 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filemanager/driver/onedrive/api.go @@ -4,23 +4,16 @@ import ( "context" "encoding/json" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/request" "io" "net/http" "net/url" "path" - "strconv" "strings" "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" ) const ( @@ -35,6 +28,8 @@ const ( notFoundError = "itemNotFound" ) +type RetryCtx struct{} + // GetSourcePath 获取文件的绝对路径 func (info *FileInfo) GetSourcePath() string { res, err := url.PathUnescape(info.ParentReference.Path) @@ -51,19 +46,19 @@ func (info *FileInfo) GetSourcePath() string { ) } -func (client *Client) getRequestURL(api string, opts ...Option) string { +func (client *client) getRequestURL(api string, opts ...Option) string { options := newDefaultOption() for _, o := range opts { o.apply(options) } - base, _ := url.Parse(client.Endpoints.EndpointURL) + base, _ := url.Parse(client.endpoints.endpointURL) if base == nil { return "" } if options.useDriverResource { - base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api) + base.Path = path.Join(base.Path, client.endpoints.driverResource, api) } else { base.Path = path.Join(base.Path, api) } @@ -72,7 +67,7 @@ func (client *Client) getRequestURL(api string, opts ...Option) string { } // ListChildren 根据路径列取子对象 -func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) { +func (client *client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) { var requestURL string dst := strings.TrimPrefix(path, "/") if dst == "" { @@ -84,14 +79,14 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200) if err != nil { retried := 0 - if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok { + if v, ok := 
ctx.Value(RetryCtx{}).(int); ok { retried = v } if retried < ListRetry { retried++ - util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err) + client.l.Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err) time.Sleep(time.Duration(5) * time.Second) - return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path) + return client.ListChildren(context.WithValue(ctx, RetryCtx{}, retried), path) } return nil, err } @@ -109,7 +104,7 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo } // Meta 根据资源ID或文件路径获取文件元信息 -func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) { +func (client *client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) { var requestURL string if id != "" { requestURL = client.getRequestURL("items/" + id) @@ -137,7 +132,7 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn } // CreateUploadSession 创建分片上传会话 -func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) { +func (client *client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) { options := newDefaultOption() for _, o := range opts { o.apply(options) @@ -170,7 +165,7 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts } // GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID -func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) { +func (client *client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) { siteUrlParsed, err := url.Parse(siteUrl) if err != nil { return "", err @@ -197,7 +192,7 @@ func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (strin } // GetUploadSessionStatus 查询上传会话状态 -func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) { +func (client *client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) { res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200) if err != nil { return nil, err @@ -216,7 +211,7 @@ func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL stri } // UploadChunk 上传分片 -func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) { +func (client *client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) { res, err := client.request( ctx, "PUT", uploadURL, content, request.WithContentLength(current.Length()), @@ -247,16 +242,15 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content } // Upload 上传文件 -func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error { - fileInfo := file.Info() +func (client *client) Upload(ctx context.Context, file *fs.UploadRequest) error { // 决定是否覆盖文件 overwrite := "fail" - if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite { + if file.Mode&fs.ModeOverwrite == fs.ModeOverwrite { overwrite = "replace" } - size := int(fileInfo.Size) - dst := fileInfo.SavePath + size := int(file.Props.Size) + dst := file.Props.SavePath // 小文件,使用简单上传接口上传 if size <= int(SmallFileSize) { @@ -272,10 +266,10 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error { } // Initial chunk groups - chunks := chunk.NewChunkGroup(file, 
client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{ - Max: model.GetIntSetting("chunk_retries", 5), + chunks := chunk.NewChunkGroup(file, client.chunkSize, &backoff.ConstantBackoff{ + Max: client.settings.ChunkRetryLimit(ctx), Sleep: chunkRetrySleep, - }, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer"))) + }, client.settings.UseChunkBuffer(ctx), client.l, client.settings.TempPath(ctx)) uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { _, err := client.UploadChunk(ctx, uploadURL, content, current) @@ -285,6 +279,9 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error { // upload chunks for chunks.Next() { if err := chunks.Process(uploadFunc); err != nil { + if err := client.DeleteUploadSession(ctx, uploadURL); err != nil { + client.l.Warning("Failed to delete upload session: %s", err) + } return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) } } @@ -293,7 +290,7 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error { } // DeleteUploadSession 删除上传会话 -func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error { +func (client *client) DeleteUploadSession(ctx context.Context, uploadURL string) error { _, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204) if err != nil { return err @@ -303,7 +300,7 @@ func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) } // SimpleUpload 上传小文件到dst -func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) { +func (client *client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) { options := newDefaultOption() for _, o := range opts { o.apply(options) @@ -334,8 +331,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read // BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为 // 20个一组,调用Delete并行删除 -// TODO 测试 -func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) { +func (client *client) BatchDelete(ctx context.Context, dst []string) ([]string, error) { groupNum := len(dst)/20 + 1 finalRes := make([]string, 0, len(dst)) res := make([]string, 0, 20) @@ -346,6 +342,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, if i == groupNum-1 { end = len(dst) } + + client.l.Debug("Delete file group: %v.", dst[20*i:end]) res, err = client.Delete(ctx, dst[20*i:end]) finalRes = append(finalRes, res...) 
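+ // res holds the paths this 20-item group failed to delete (the Graph $batch API caps each request at 20); they are merged into finalRes for the caller.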
} @@ -355,7 +353,7 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, // Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误, // 由于API限制,最多删除20个 -func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) { +func (client *client) Delete(ctx context.Context, dst []string) ([]string, error) { body := client.makeBatchDeleteRequestsBody(dst) res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch", WithDriverResource(false)), body, 200) @@ -391,13 +389,13 @@ func getDeleteFailed(res *BatchResponses) []string { } // makeBatchDeleteRequestsBody 生成批量删除请求正文 -func (client *Client) makeBatchDeleteRequestsBody(files []string) string { +func (client *client) makeBatchDeleteRequestsBody(files []string) string { req := BatchRequests{ Requests: make([]BatchRequest, len(files)), } for i, v := range files { v = strings.TrimPrefix(v, "/") - filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/") + filePath, _ := url.Parse("/" + client.endpoints.driverResource + "/root:/") filePath.Path = path.Join(filePath.Path, v) req.Requests[i] = BatchRequest{ ID: v, @@ -411,7 +409,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string { } // GetThumbURL 获取给定尺寸的缩略图URL -func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) { +func (client *client) GetThumbURL(ctx context.Context, dst string) (string, error) { dst = strings.TrimPrefix(dst, "/") requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large" @@ -442,82 +440,6 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s return "", ErrThumbSizeNotFound } -// MonitorUpload 监控客户端分片上传进度 -func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) { - // 回调完成通知chan - callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1) - defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan) - - timeout := model.GetIntSetting("onedrive_monitor_timeout", 600) - interval := model.GetIntSetting("onedrive_callback_check", 20) - - for { - select { - case <-callbackChan: - util.Log().Debug("Client finished OneDrive callback.") - return - case <-time.After(time.Duration(ttl) * time.Second): - // 上传会话到期,仍未完成上传,创建占位符 - client.DeleteUploadSession(context.Background(), uploadURL) - _, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace")) - if err != nil { - util.Log().Debug("Failed to create placeholder file: %s", err) - } - return - case <-time.After(time.Duration(timeout) * time.Second): - util.Log().Debug("Checking OneDrive upload status.") - status, err := client.GetUploadSessionStatus(context.Background(), uploadURL) - - if err != nil { - if resErr, ok := err.(*RespError); ok { - if resErr.APIError.Code == notFoundError { - util.Log().Debug("Upload completed, will check upload callback later.") - select { - case <-time.After(time.Duration(interval) * time.Second): - util.Log().Warning("No callback is made, file will be deleted.") - cache.Deletes([]string{callbackKey}, "callback_") - _, err = client.Delete(context.Background(), []string{path}) - if err != nil { - util.Log().Warning("Failed to delete file without callback: %s", err) - } - case <-callbackChan: - util.Log().Debug("Client finished callback.") - } - return - } - } - util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error()) - continue - } - - // 成功获取分片上传状态,检查文件大小 - if len(status.NextExpectedRanges) == 0 { - 
continue - } - sizeRange := strings.Split( - status.NextExpectedRanges[len(status.NextExpectedRanges)-1], - "-", - ) - if len(sizeRange) != 2 { - continue - } - uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64) - if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size { - util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...") - // 取消上传会话,实测OneDrive取消上传会话后,客户端还是可以上传, - // 所以上传一个空文件占位,阻止客户端上传 - client.DeleteUploadSession(context.Background(), uploadURL) - _, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace")) - if err != nil { - util.Log().Debug("无法创建占位文件,%s", err) - } - return - } - - } - } -} - func sysError(err error) *RespError { return &RespError{APIError: APIError{ Code: "system", @@ -525,32 +447,32 @@ func sysError(err error) *RespError { }} } -func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) { +func (client *client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) { // 获取凭证 - err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave") + err := client.UpdateCredential(ctx) if err != nil { return "", sysError(err) } - option = append(option, + opts := []request.Option{ request.WithHeader(http.Header{ - "Authorization": {"Bearer " + client.Credential.AccessToken}, + "Authorization": {"Bearer " + client.credential.String()}, "Content-Type": {"application/json"}, }), request.WithContext(ctx), request.WithTPSLimit( - fmt.Sprintf("policy_%d", client.Policy.ID), - client.Policy.OptionsSerialized.TPSLimit, - client.Policy.OptionsSerialized.TPSLimitBurst, + fmt.Sprintf("policy_%d", client.policy.ID), + client.policy.Settings.TPSLimit, + client.policy.Settings.TPSLimitBurst, ), - ) + } // 发送请求 - res := client.Request.Request( + res := client.httpClient.Request( method, url, body, - option..., + append(opts, option...)..., ) if res.Err != nil { @@ -571,12 +493,12 @@ func (client *Client) request(ctx context.Context, method string, url string, bo if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 { decodeErr = json.Unmarshal([]byte(respBody), &errResp) if decodeErr != nil { - util.Log().Debug("Onedrive returns unknown response: %s", respBody) + client.l.Debug("Onedrive returns unknown response: %s", respBody) return "", sysError(decodeErr) } if res.Response.StatusCode == 429 { - util.Log().Warning("OneDrive request is throttled.") + client.l.Warning("OneDrive request is throttled.") return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header) } @@ -586,7 +508,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo return respBody, nil } -func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) { +func (client *client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) { // 发送请求 bodyReader := io.NopCloser(strings.NewReader(body)) return client.request(ctx, method, url, bodyReader, diff --git a/pkg/filemanager/driver/onedrive/client.go b/pkg/filemanager/driver/onedrive/client.go new file mode 100644 index 00000000..ea5e5dc0 --- /dev/null +++ b/pkg/filemanager/driver/onedrive/client.go @@ -0,0 +1,90 @@ +package onedrive + +import ( + "context" + "errors" + "io" + + "github.com/cloudreve/Cloudreve/v4/ent" + 
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + + "github.com/cloudreve/Cloudreve/v4/pkg/request" +) + +var ( + // ErrAuthEndpoint 无法解析授权端点地址 + ErrAuthEndpoint = errors.New("failed to parse endpoint url") + // ErrInvalidRefreshToken 上传策略无有效的RefreshToken + ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy") + // ErrDeleteFile 无法删除文件 + ErrDeleteFile = errors.New("cannot delete file") + // ErrClientCanceled 客户端取消操作 + ErrClientCanceled = errors.New("client canceled") + // Desired thumb size not available + ErrThumbSizeNotFound = errors.New("thumb size not found") +) + +type Client interface { + ListChildren(ctx context.Context, path string) ([]FileInfo, error) + Meta(ctx context.Context, id string, path string) (*FileInfo, error) + CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) + GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) + GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) + Upload(ctx context.Context, file *fs.UploadRequest) error + SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) + DeleteUploadSession(ctx context.Context, uploadURL string) error + BatchDelete(ctx context.Context, dst []string) ([]string, error) + GetThumbURL(ctx context.Context, dst string) (string, error) + OAuthURL(ctx context.Context, scopes []string) string + ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) +} + +// client OneDrive客户端 +type client struct { + endpoints *endpoints + policy *ent.StoragePolicy + credential credmanager.Credential + + httpClient request.Client + cred credmanager.CredManager + l logging.Logger + settings setting.Provider + + chunkSize int64 +} + +// endpoints OneDrive客户端相关设置 +type endpoints struct { + oAuthEndpoints *oauthEndpoint + endpointURL string // 接口请求的基URL + driverResource string // 要使用的驱动器 +} + +// NewClient 根据存储策略获取新的client +func NewClient(policy *ent.StoragePolicy, httpClient request.Client, cred credmanager.CredManager, + l logging.Logger, settings setting.Provider, chunkSize int64) Client { + client := &client{ + endpoints: &endpoints{ + endpointURL: policy.Server, + driverResource: policy.Settings.OdDriver, + }, + policy: policy, + httpClient: httpClient, + cred: cred, + l: l, + settings: settings, + chunkSize: chunkSize, + } + + if client.endpoints.driverResource == "" { + client.endpoints.driverResource = "me/drive" + } + + oauthBase := getOAuthEndpoint(policy.Server) + client.endpoints.oAuthEndpoints = oauthBase + + return client +} diff --git a/pkg/filemanager/driver/onedrive/oauth.go b/pkg/filemanager/driver/onedrive/oauth.go new file mode 100644 index 00000000..4f94402a --- /dev/null +++ b/pkg/filemanager/driver/onedrive/oauth.go @@ -0,0 +1,271 @@ +package onedrive + +import ( + "context" + "encoding/gob" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/samber/lo" +) + +const ( + AccessTokenExpiryMargin = 600 // 10 minutes +) + +// Error 
实现error接口 +func (err OAuthError) Error() string { + return err.ErrorDescription +} + +// OAuthURL 获取OAuth认证页面URL +func (client *client) OAuthURL(ctx context.Context, scope []string) string { + query := url.Values{ + "client_id": {client.policy.BucketName}, + "scope": {strings.Join(scope, " ")}, + "response_type": {"code"}, + "redirect_uri": {client.policy.Settings.OauthRedirect}, + "state": {strconv.Itoa(client.policy.ID)}, + } + client.endpoints.oAuthEndpoints.authorize.RawQuery = query.Encode() + return client.endpoints.oAuthEndpoints.authorize.String() +} + +// getOAuthEndpoint gets OAuth endpoints from API endpoint +func getOAuthEndpoint(apiEndpoint string) *oauthEndpoint { + base, err := url.Parse(apiEndpoint) + if err != nil { + return nil + } + var ( + token *url.URL + authorize *url.URL + ) + switch base.Host { + //case "login.live.com": + // token, _ = url.Parse("https://login.live.com/oauth20_token.srf") + // authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf") + case "microsoftgraph.chinacloudapi.cn": + token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token") + authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize") + default: + token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token") + authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize") + } + + return &oauthEndpoint{ + token: *token, + authorize: *authorize, + } +} + +// Credential 获取token时返回的凭证 +type Credential struct { + ExpiresIn int64 `json:"expires_in"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + RefreshedAtUnix int64 `json:"refreshed_at"` + + PolicyID int `json:"policy_id"` +} + +func init() { + gob.Register(Credential{}) +} + +func (c Credential) Refresh(ctx context.Context) (credmanager.Credential, error) { + if c.RefreshToken == "" { + return nil, ErrInvalidRefreshToken + } + + dep := dependency.FromContext(ctx) + storagePolicyClient := dep.StoragePolicyClient() + policy, err := storagePolicyClient.GetPolicyByID(ctx, c.PolicyID) + if err != nil { + return nil, fmt.Errorf("failed to get storage policy: %w", err) + } + + oauthBase := getOAuthEndpoint(policy.Server) + + newCredential, err := obtainToken(ctx, &obtainTokenArgs{ + clientId: policy.BucketName, + redirect: policy.Settings.OauthRedirect, + secret: policy.SecretKey, + refreshToken: c.RefreshToken, + client: dep.RequestClient(request.WithLogger(dep.Logger())), + tokenEndpoint: oauthBase.token.String(), + policyID: c.PolicyID, + }) + + if err != nil { + return nil, err + } + + c.RefreshToken = newCredential.RefreshToken + c.AccessToken = newCredential.AccessToken + c.ExpiresIn = newCredential.ExpiresIn + c.RefreshedAtUnix = time.Now().Unix() + + // Write refresh token to db + if err := storagePolicyClient.UpdateAccessKey(ctx, policy, newCredential.RefreshToken); err != nil { + return nil, err + } + + return c, nil +} + +func (c Credential) Key() string { + return CredentialKey(c.PolicyID) +} + +func (c Credential) Expiry() time.Time { + return time.Unix(c.ExpiresIn-AccessTokenExpiryMargin, 0) +} + +func (c Credential) String() string { + return c.AccessToken +} + +func (c Credential) RefreshedAt() *time.Time { + if c.RefreshedAtUnix == 0 { + return nil + } + refreshedAt := time.Unix(c.RefreshedAtUnix, 0) + return &refreshedAt +} + +// ObtainToken 通过code或refresh_token兑换token +func (client *client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) { + options := 
newDefaultOption() + for _, o := range opts { + o.apply(options) + } + + return obtainToken(ctx, &obtainTokenArgs{ + clientId: client.policy.BucketName, + redirect: client.policy.Settings.OauthRedirect, + secret: client.policy.SecretKey, + code: options.code, + refreshToken: options.refreshToken, + client: client.httpClient, + tokenEndpoint: client.endpoints.oAuthEndpoints.token.String(), + policyID: client.policy.ID, + }) + +} + +type obtainTokenArgs struct { + clientId string + redirect string + secret string + code string + refreshToken string + client request.Client + tokenEndpoint string + policyID int +} + +// obtainToken fetch new access token from Microsoft Graph API +func obtainToken(ctx context.Context, args *obtainTokenArgs) (*Credential, error) { + body := url.Values{ + "client_id": {args.clientId}, + "redirect_uri": {args.redirect}, + "client_secret": {args.secret}, + } + if args.code != "" { + body.Add("grant_type", "authorization_code") + body.Add("code", args.code) + } else { + body.Add("grant_type", "refresh_token") + body.Add("refresh_token", args.refreshToken) + } + strBody := body.Encode() + + res := args.client.Request( + "POST", + args.tokenEndpoint, + io.NopCloser(strings.NewReader(strBody)), + request.WithHeader(http.Header{ + "Content-Type": {"application/x-www-form-urlencoded"}}, + ), + request.WithContentLength(int64(len(strBody))), + request.WithContext(ctx), + ) + if res.Err != nil { + return nil, res.Err + } + + respBody, err := res.GetResponse() + if err != nil { + return nil, err + } + + var ( + errResp OAuthError + credential Credential + decodeErr error + ) + + if res.Response.StatusCode != 200 { + decodeErr = json.Unmarshal([]byte(respBody), &errResp) + } else { + decodeErr = json.Unmarshal([]byte(respBody), &credential) + } + if decodeErr != nil { + return nil, decodeErr + } + + if errResp.ErrorType != "" { + return nil, errResp + } + + credential.PolicyID = args.policyID + credential.ExpiresIn = time.Now().Unix() + credential.ExpiresIn + if args.code != "" { + credential.ExpiresIn = time.Now().Unix() - 10 + } + return &credential, nil +} + +// UpdateCredential 更新凭证,并检查有效期 +func (client *client) UpdateCredential(ctx context.Context) error { + newCred, err := client.cred.Obtain(ctx, CredentialKey(client.policy.ID)) + if err != nil { + return fmt.Errorf("failed to obtain token from CredManager: %w", err) + } + + client.credential = newCred + return nil +} + +// RetrieveOneDriveCredentials retrieves OneDrive credentials from DB inventory +func RetrieveOneDriveCredentials(ctx context.Context, storagePolicyClient inventory.StoragePolicyClient) ([]credmanager.Credential, error) { + odPolicies, err := storagePolicyClient.ListPolicyByType(ctx, types.PolicyTypeOd) + if err != nil { + return nil, fmt.Errorf("failed to list OneDrive policies: %w", err) + } + + return lo.Map(odPolicies, func(item *ent.StoragePolicy, index int) credmanager.Credential { + return &Credential{ + PolicyID: item.ID, + ExpiresIn: 0, + RefreshToken: item.AccessKey, + } + }), nil +} + +func CredentialKey(policyId int) string { + return fmt.Sprintf("cred_od_%d", policyId) +} diff --git a/pkg/filemanager/driver/onedrive/onedrive.go b/pkg/filemanager/driver/onedrive/onedrive.go new file mode 100644 index 00000000..5ac5ecf7 --- /dev/null +++ b/pkg/filemanager/driver/onedrive/onedrive.go @@ -0,0 +1,247 @@ +package onedrive + +import ( + "context" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + 
"github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "net/url" + "os" + "strings" + "time" +) + +// Driver OneDrive 适配器 +type Driver struct { + policy *ent.StoragePolicy + client Client + settings setting.Provider + config conf.ConfigProvider + l logging.Logger + chunkSize int64 +} + +var ( + features = &boolset.BooleanSet{} +) + +const ( + streamSaverParam = "stream_saver" +) + +func init() { + boolset.Sets(map[driver.HandlerCapability]bool{ + driver.HandlerCapabilityUploadSentinelRequired: true, + }, features) +} + +// NewDriver 从存储策略初始化新的Driver实例 +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, cred credmanager.CredManager) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 50 << 20 // 50MB + } + + c := NewClient(policy, request.NewClient(config, request.WithLogger(l)), cred, l, settings, chunkSize) + + return &Driver{ + policy: policy, + client: c, + settings: settings, + l: l, + config: config, + chunkSize: chunkSize, + }, nil +} + +//// List 列取项目 +//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// base = strings.TrimPrefix(base, "/") +// // 列取子项目 +// objects, _ := handler.client.ListChildren(ctx, base) +// +// // 获取真实的列取起始根目录 +// rootPath := base +// if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok { +// rootPath = realBase +// } else { +// ctx = context.WithValue(ctx, fsctx.PathCtx, base) +// } +// +// // 整理结果 +// res := make([]response.Object, 0, len(objects)) +// for _, object := range objects { +// source := path.Join(base, object.Name) +// rel, err := filepath.Rel(rootPath, source) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: object.Name, +// RelativePath: filepath.ToSlash(rel), +// Source: source, +// Size: uint64(object.Size), +// IsDir: object.Folder != nil, +// LastModify: time.Now(), +// }) +// } +// +// // 递归列取子目录 +// if recursive { +// for _, object := range objects { +// if object.Folder != nil { +// sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive) +// res = append(res, sub...) 
+// } +// } +// } +// +// return res, nil +//} + +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + return handler.client.Upload(ctx, file) +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + return handler.client.BatchDelete(ctx, files) +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + res, err := handler.client.GetThumbURL(ctx, e.Source()) + if err != nil { + var apiErr *RespError + if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) { + // OneDrive cannot generate thumbnail for this file + return "", fmt.Errorf("thumb not supported in OneDrive: %w", err) + } + } + + return res, nil +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + // 缓存不存在,重新获取 + res, err := handler.client.Meta(ctx, "", e.Source()) + if err != nil { + return "", err + } + + if args.IsDownload && handler.policy.Settings.StreamSaver { + downloadUrl := res.DownloadURL + "&" + streamSaverParam + "=" + url.QueryEscape(args.DisplayName) + return downloadUrl, nil + } + + return res.DownloadURL, nil +} + +// Token 获取上传会话URL +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // 生成回调地址 + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeOd, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + uploadURL, err := handler.client.CreateUploadSession(ctx, file.Props.SavePath, WithConflictBehavior("fail")) + if err != nil { + return nil, err + } + + // 监控回调及上传 + //go handler.client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl) + + uploadSession.ChunkSize = handler.chunkSize + uploadSession.UploadURL = uploadURL + return &fs.UploadCredential{ + ChunkSize: handler.chunkSize, + UploadURLs: []string{uploadURL}, + }, nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + err := handler.client.DeleteUploadSession(ctx, uploadSession.UploadURL) + // Create empty placeholder file to stop upload + if err == nil { + _, err := handler.client.SimpleUpload(ctx, uploadSession.Props.SavePath, strings.NewReader(""), 0, WithConflictBehavior("replace")) + if err != nil { + handler.l.Warning("Failed to create placeholder file %q:%s", uploadSession.Props.SavePath, err) + } + } + + return err +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + if session.SentinelTaskID == 0 { + return nil + } + + // Make sure uploaded file size is correct + res, err := handler.client.Meta(ctx, "", session.Props.SavePath) + if err != nil { + // Create empty placeholder file to stop further upload + + return fmt.Errorf("failed to get uploaded file size: %w", err) + } + + isSharePoint := strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.com") || + strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.cn") + sizeMismatch := res.Size != session.Props.Size + // SharePoint 
会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容 + // See: https://github.com/OneDrive/onedrive-api-docs/issues/935 + if isSharePoint && sizeMismatch && (res.Size > session.Props.Size) && (res.Size-session.Props.Size <= 1048576) { + sizeMismatch = false + } + + if sizeMismatch { + return serializer.NewError( + serializer.CodeMetaMismatch, + fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.Size), + nil, + ) + } + + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + return &driver.Capabilities{ + StaticFeatures: features, + ThumbSupportedExts: handler.policy.Settings.ThumbExts, + ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts, + ThumbMaxSize: handler.policy.Settings.ThumbMaxSize, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + return nil, errors.New("not implemented") +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} diff --git a/pkg/filesystem/driver/onedrive/options.go b/pkg/filemanager/driver/onedrive/options.go similarity index 100% rename from pkg/filesystem/driver/onedrive/options.go rename to pkg/filemanager/driver/onedrive/options.go diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filemanager/driver/onedrive/types.go similarity index 89% rename from pkg/filesystem/driver/onedrive/types.go rename to pkg/filemanager/driver/onedrive/types.go index 2a2ea4ce..f430d3ec 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filemanager/driver/onedrive/types.go @@ -27,7 +27,7 @@ type UploadSessionResponse struct { // FileInfo 文件元信息 type FileInfo struct { Name string `json:"name"` - Size uint64 `json:"size"` + Size int64 `json:"size"` Image imageInfo `json:"image"` ParentReference parentReference `json:"parentReference"` DownloadURL string `json:"@microsoft.graph.downloadUrl"` @@ -104,16 +104,6 @@ type oauthEndpoint struct { authorize url.URL } -// Credential 获取token时返回的凭证 -type Credential struct { - TokenType string `json:"token_type"` - ExpiresIn int64 `json:"expires_in"` - Scope string `json:"scope"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - UserID string `json:"user_id"` -} - // OAuthError OAuth相关接口的错误响应 type OAuthError struct { ErrorType string `json:"error"` diff --git a/pkg/filesystem/driver/oss/callback.go b/pkg/filemanager/driver/oss/callback.go similarity index 68% rename from pkg/filesystem/driver/oss/callback.go rename to pkg/filemanager/driver/oss/callback.go index b2a88035..36a70bb2 100644 --- a/pkg/filesystem/driver/oss/callback.go +++ b/pkg/filemanager/driver/oss/callback.go @@ -10,39 +10,44 @@ import ( "encoding/pem" "errors" "fmt" - "io/ioutil" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "io" "net/http" "net/url" "strings" +) - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/request" +const ( + pubKeyHeader = "x-oss-pub-key-url" + pubKeyPrefix = "http://gosspublic.alicdn.com/" + pubKeyPrefixHttps = "https://gosspublic.alicdn.com/" + pubKeyCacheKey = "oss_public_key" ) // GetPublicKey 从回调请求或缓存中获取OSS的回调签名公钥 -func GetPublicKey(r *http.Request) ([]byte, error) { +func GetPublicKey(r *http.Request, kv cache.Driver, client request.Client) ([]byte, error) { var pubKey []byte // 尝试从缓存中获取 - pub, exist := 
cache.Get("oss_public_key") + pub, exist := kv.Get(pubKeyCacheKey) if exist { return pub.([]byte), nil } // 从请求中获取 - pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url")) + pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get(pubKeyHeader)) if err != nil { return pubKey, err } // 确保这个 public key 是由 OSS 颁发的 - if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") && - !strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") { + if !strings.HasPrefix(string(pubURL), pubKeyPrefix) && + !strings.HasPrefix(string(pubURL), pubKeyPrefixHttps) { return pubKey, errors.New("public key url invalid") } // 获取公钥 - client := request.NewClient() body, err := client.Request("GET", string(pubURL), nil). CheckHTTPResponse(200). GetResponse() @@ -51,7 +56,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) { } // 写入缓存 - _ = cache.Set("oss_public_key", []byte(body), 86400*7) + _ = kv.Set(pubKeyCacheKey, []byte(body), 86400*7) return []byte(body), nil } @@ -60,12 +65,12 @@ func getRequestMD5(r *http.Request) ([]byte, error) { var byteMD5 []byte // 获取请求正文 - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) r.Body.Close() if err != nil { return byteMD5, err } - r.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.Body = io.NopCloser(bytes.NewReader(body)) strURLPathDecode, err := url.PathUnescape(r.URL.Path) if err != nil { @@ -81,8 +86,8 @@ func getRequestMD5(r *http.Request) ([]byte, error) { } // VerifyCallbackSignature 验证OSS回调请求 -func VerifyCallbackSignature(r *http.Request) error { - bytePublicKey, err := GetPublicKey(r) +func VerifyCallbackSignature(r *http.Request, kv cache.Driver, client request.Client) error { + bytePublicKey, err := GetPublicKey(r, kv, client) if err != nil { return err } diff --git a/pkg/filemanager/driver/oss/media.go b/pkg/filemanager/driver/oss/media.go new file mode 100644 index 00000000..210eec16 --- /dev/null +++ b/pkg/filemanager/driver/oss/media.go @@ -0,0 +1,359 @@ +package oss + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/samber/lo" + "math" + "net/http" + "strconv" + "strings" + "time" +) + +const ( + imageInfoProcess = "image/info" + videoInfoProcess = "video/info" + audioInfoProcess = "audio/info" + mediaInfoTTL = time.Duration(10) * time.Minute +) + +var ( + supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff", "heic", "heif"} + supportedAudioExt = []string{"mp3", "wav", "flac", "aac", "m4a", "ogg", "wma", "ape", "alac", "amr", "opus"} + supportedVideoExt = []string{"mp4", "mkv", "avi", "mov", "flv", "wmv", "rmvb", "webm", "3gp", "mpg", "mpeg", "m4v", "ts", "m3u8", "vob", "f4v", "rm", "asf", "divx", "ogv", "dat", "mts", "m2ts", "swf", "avi", "3g2", "m2v", "m4p", "m4b", "m4r", "m4v", "m4a"} +) + +type ( + ImageProp struct { + Value string `json:"value"` + } + ImageInfo map[string]ImageProp + + Error struct { + XMLName xml.Name `xml:"Error"` + Text string `xml:",chardata"` + Code string `xml:"Code"` + Message string `xml:"Message"` + RequestId string `xml:"RequestId"` + HostId string `xml:"HostId"` + EC string `xml:"EC"` + RecommendDoc string `xml:"RecommendDoc"` + } + + StreamMediaInfo struct { + RequestID string `json:"RequestId"` + Language string `json:"Language"` + Title string `json:"Title"` + VideoStreams 
[]VideoStream `json:"VideoStreams"` + AudioStreams []AudioStream `json:"AudioStreams"` + Subtitles []Subtitle `json:"Subtitles"` + StreamCount int64 `json:"StreamCount"` + ProgramCount int64 `json:"ProgramCount"` + FormatName string `json:"FormatName"` + FormatLongName string `json:"FormatLongName"` + Size int64 `json:"Size"` + StartTime float64 `json:"StartTime"` + Bitrate int64 `json:"Bitrate"` + Artist string `json:"Artist"` + AlbumArtist string `json:"AlbumArtist"` + Composer string `json:"Composer"` + Performer string `json:"Performer"` + Album string `json:"Album"` + Duration float64 `json:"Duration"` + ProduceTime string `json:"ProduceTime"` + LatLong string `json:"LatLong"` + VideoWidth int64 `json:"VideoWidth"` + VideoHeight int64 `json:"VideoHeight"` + Addresses []Address `json:"Addresses"` + } + + Address struct { + Language string `json:"Language"` + AddressLine string `json:"AddressLine"` + Country string `json:"Country"` + Province string `json:"Province"` + City string `json:"City"` + District string `json:"District"` + Township string `json:"Township"` + } + + AudioStream struct { + Index int `json:"Index"` + Language string `json:"Language"` + CodecName string `json:"CodecName"` + CodecLongName string `json:"CodecLongName"` + CodecTimeBase string `json:"CodecTimeBase"` + CodecTagString string `json:"CodecTagString"` + CodecTag string `json:"CodecTag"` + TimeBase string `json:"TimeBase"` + StartTime float64 `json:"StartTime"` + Duration float64 `json:"Duration"` + Bitrate int64 `json:"Bitrate"` + FrameCount int64 `json:"FrameCount"` + Lyric string `json:"Lyric"` + SampleFormat string `json:"SampleFormat"` + SampleRate int64 `json:"SampleRate"` + Channels int64 `json:"Channels"` + ChannelLayout string `json:"ChannelLayout"` + } + + Subtitle struct { + Index int64 `json:"Index"` + Language string `json:"Language"` + CodecName string `json:"CodecName"` + CodecLongName string `json:"CodecLongName"` + CodecTagString string `json:"CodecTagString"` + CodecTag string `json:"CodecTag"` + StartTime float64 `json:"StartTime"` + Duration float64 `json:"Duration"` + Bitrate int64 `json:"Bitrate"` + Content string `json:"Content"` + Width int64 `json:"Width"` + Height int64 `json:"Height"` + } + + VideoStream struct { + Index int `json:"Index"` + Language string `json:"Language"` + CodecName string `json:"CodecName"` + CodecLongName string `json:"CodecLongName"` + Profile string `json:"Profile"` + CodecTimeBase string `json:"CodecTimeBase"` + CodecTagString string `json:"CodecTagString"` + CodecTag string `json:"CodecTag"` + Width int `json:"Width"` + Height int `json:"Height"` + HasBFrames int `json:"HasBFrames"` + SampleAspectRatio string `json:"SampleAspectRatio"` + DisplayAspectRatio string `json:"DisplayAspectRatio"` + PixelFormat string `json:"PixelFormat"` + Level int `json:"Level"` + FrameRate string `json:"FrameRate"` + AverageFrameRate string `json:"AverageFrameRate"` + TimeBase string `json:"TimeBase"` + StartTime float64 `json:"StartTime"` + Duration float64 `json:"Duration"` + Bitrate int64 `json:"Bitrate"` + FrameCount int64 `json:"FrameCount"` + Rotate string `json:"Rotate"` + BitDepth int `json:"BitDepth"` + ColorSpace string `json:"ColorSpace"` + ColorRange string `json:"ColorRange"` + ColorTransfer string `json:"ColorTransfer"` + ColorPrimaries string `json:"ColorPrimaries"` + } +) + +func (handler *Driver) extractIMMMeta(ctx context.Context, path, category string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, category, true) + if err 
!= nil { + return nil, err + } + + var info StreamMediaInfo + if err := json.Unmarshal([]byte(resp), &info); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + streams := lo.Map(info.VideoStreams, func(stream VideoStream, index int) mediameta.Stream { + bitrate := "" + if stream.Bitrate != 0 { + bitrate = strconv.FormatInt(stream.Bitrate, 10) + } + return mediameta.Stream{ + Index: stream.Index, + CodecName: stream.CodecName, + CodecLongName: stream.CodecLongName, + CodecType: "video", + Width: stream.Width, + Height: stream.Height, + Duration: strconv.FormatFloat(stream.Duration, 'f', -1, 64), + Bitrate: bitrate, + } + }) + streams = append(streams, lo.Map(info.AudioStreams, func(stream AudioStream, index int) mediameta.Stream { + bitrate := "" + if stream.Bitrate != 0 { + bitrate = strconv.FormatInt(stream.Bitrate, 10) + } + return mediameta.Stream{ + Index: stream.Index, + CodecName: stream.CodecName, + CodecLongName: stream.CodecLongName, + CodecType: "audio", + Duration: strconv.FormatFloat(stream.Duration, 'f', -1, 64), + Bitrate: bitrate, + } + })...) + + metas := make([]driver.MediaMeta, 0) + metas = append(metas, mediameta.ProbeMetaTransform(&mediameta.FFProbeMeta{ + Format: &mediameta.Format{ + FormatName: info.FormatName, + FormatLongName: info.FormatLongName, + Duration: strconv.FormatFloat(info.Duration, 'f', -1, 64), + Bitrate: strconv.FormatInt(info.Bitrate, 10), + }, + Streams: streams, + })...) + + if info.Artist != "" { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicArtist, + Value: info.Artist, + Type: driver.MediaTypeMusic, + }) + } + + if info.AlbumArtist != "" { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicAlbumArtists, + Value: info.AlbumArtist, + Type: driver.MediaTypeMusic, + }) + } + + if info.Composer != "" { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicComposer, + Value: info.Composer, + Type: driver.MediaTypeMusic, + }) + } + + if info.Album != "" { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicAlbum, + Value: info.Album, + Type: driver.MediaTypeMusic, + }) + } + + return metas, nil +} + +func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, imageInfoProcess, false) + if err != nil { + return nil, err + } + + var imageInfo ImageInfo + if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + metas := make([]driver.MediaMeta, 0) + exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) { + return key, value.Value + }) + metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...) + metas = append(metas, parseGpsInfo(imageInfo)...) + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + + return metas, nil +} + +// extractMediaInfo Sends API calls to OSS IMM service to extract media info. +func (handler *Driver) extractMediaInfo(ctx context.Context, path string, category string, forceSign bool) (string, error) { + mediaOption := []oss.Option{oss.Process(category)} + mediaInfoExpire := time.Now().Add(mediaInfoTTL) + thumbURL, err := handler.signSourceURL( + ctx, + path, + &mediaInfoExpire, + mediaOption, + forceSign, + ) + if err != nil { + return "", fmt.Errorf("failed to sign media info url: %w", err) + } + + resp, err := handler.httpClient. 
+ Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)). + CheckHTTPResponse(http.StatusOK). + GetResponseIgnoreErr() + if err != nil { + return "", handleOssError(resp, err) + } + + return resp, nil +} + +func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta { + latitude := imageInfo["GPSLatitude"] // 31deg 16.26808' + longitude := imageInfo["GPSLongitude"] // 120deg 42.91039' + latRef := imageInfo["GPSLatitudeRef"] // North + lonRef := imageInfo["GPSLongitudeRef"] // East + + // Make sure all value exist in map + if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" { + return nil + } + + lat := parseRawGPS(latitude.Value, latRef.Value) + lon := parseRawGPS(longitude.Value, lonRef.Value) + if !math.IsNaN(lat) && !math.IsNaN(lon) { + lat, lng := mediameta.NormalizeGPS(lat, lon) + return []driver.MediaMeta{{ + Key: mediameta.GpsLat, + Value: fmt.Sprintf("%f", lat), + }, { + Key: mediameta.GpsLng, + Value: fmt.Sprintf("%f", lng), + }} + } + + return nil +} + +func parseRawGPS(gpsStr string, ref string) float64 { + elem := strings.Split(gpsStr, " ") + if len(elem) < 1 { + return 0 + } + + var ( + deg float64 + minutes float64 + seconds float64 + ) + + deg, _ = strconv.ParseFloat(strings.TrimSuffix(elem[0], "deg"), 64) + if len(elem) >= 2 { + minutes, _ = strconv.ParseFloat(strings.TrimSuffix(elem[1], "'"), 64) + } + if len(elem) >= 3 { + seconds, _ = strconv.ParseFloat(strings.TrimSuffix(elem[2], "\""), 64) + } + + decimal := deg + minutes/60.0 + seconds/3600.0 + + if ref == "South" || ref == "West" { + return -decimal + } + + return decimal +} + +func handleOssError(resp string, originErr error) error { + if resp == "" { + return originErr + } + + var err Error + if err := xml.Unmarshal([]byte(resp), &err); err != nil { + return fmt.Errorf("failed to unmarshal oss error: %w", err) + } + + return fmt.Errorf("oss error: %s", err.Message) +} diff --git a/pkg/filemanager/driver/oss/oss.go b/pkg/filemanager/driver/oss/oss.go new file mode 100644 index 00000000..dcad2757 --- /dev/null +++ b/pkg/filemanager/driver/oss/oss.go @@ -0,0 +1,548 @@ +package oss + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/samber/lo" +) + +// UploadPolicy 阿里云OSS上传策略 +type UploadPolicy struct { + Expiration string `json:"expiration"` + Conditions []interface{} `json:"conditions"` +} + +// CallbackPolicy 回调策略 +type CallbackPolicy struct { + CallbackURL string `json:"callbackUrl"` + CallbackBody string `json:"callbackBody"` + CallbackBodyType string `json:"callbackBodyType"` + CallbackSNI bool `json:"callbackSNI"` +} + +// Driver 阿里云OSS策略适配器 +type Driver struct { + policy 
*ent.StoragePolicy + + client *oss.Client + bucket *oss.Bucket + settings setting.Provider + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + httpClient request.Client + + chunkSize int64 +} + +type key int + +const ( + chunkRetrySleep = time.Duration(5) * time.Second + uploadIdParam = "uploadId" + partNumberParam = "partNumber" + callbackParam = "callback" + completeAllHeader = "x-oss-complete-all" + maxDeleteBatch = 1000 + + // MultiPartUploadThreshold 服务端使用分片上传的阈值 + MultiPartUploadThreshold int64 = 5 * (1 << 30) // 5GB +) + +var ( + features = &boolset.BooleanSet{} +) + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 25 << 20 // 25 MB + } + + driver := &Driver{ + policy: policy, + settings: settings, + chunkSize: chunkSize, + config: config, + l: l, + mime: mime, + httpClient: request.NewClient(config, request.WithLogger(l)), + } + + return driver, driver.InitOSSClient(false) +} + +// CORS 创建跨域策略 +func (handler *Driver) CORS() error { + return handler.client.SetBucketCORS(handler.policy.BucketName, []oss.CORSRule{ + { + AllowedOrigin: []string{"*"}, + AllowedMethod: []string{ + "GET", + "POST", + "PUT", + "DELETE", + "HEAD", + }, + ExposeHeader: []string{}, + AllowedHeader: []string{"*"}, + MaxAgeSeconds: 3600, + }, + }) +} + +// InitOSSClient 初始化OSS鉴权客户端 +func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error { + if handler.policy == nil { + return errors.New("empty policy") + } + + opt := make([]oss.ClientOption, 0) + + // 决定是否使用内网 Endpoint + endpoint := handler.policy.Server + if handler.policy.Settings.ServerSideEndpoint != "" && !forceUsePublicEndpoint { + endpoint = handler.policy.Settings.ServerSideEndpoint + } else if handler.policy.Settings.UseCname { + opt = append(opt, oss.UseCname(true)) + } + + if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { + endpoint = "https://" + endpoint + } + + // 初始化客户端 + client, err := oss.New(endpoint, handler.policy.AccessKey, handler.policy.SecretKey, opt...) + if err != nil { + return err + } + handler.client = client + + // 初始化存储桶 + bucket, err := client.Bucket(handler.policy.BucketName) + if err != nil { + return err + } + handler.bucket = bucket + + return nil +} + +//// List 列出OSS上的文件 +//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// // 列取文件 +// base = strings.TrimPrefix(base, "/") +// if base != "" { +// base += "/" +// } +// +// var ( +// delimiter string +// marker string +// objects []oss.ObjectProperties +// commons []string +// ) +// if !recursive { +// delimiter = "/" +// } +// +// for { +// subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base), +// oss.MaxKeys(1000), oss.Delimiter(delimiter)) +// if err != nil { +// return nil, err +// } +// objects = append(objects, subRes.Objects...) +// commons = append(commons, subRes.CommonPrefixes...) 
+// marker = subRes.NextMarker +// if marker == "" { +// break +// } +// } +// +// // 处理列取结果 +// res := make([]response.Object, 0, len(objects)+len(commons)) +// // 处理目录 +// for _, object := range commons { +// rel, err := filepath.Rel(base, object) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object), +// RelativePath: filepath.ToSlash(rel), +// Size: 0, +// IsDir: true, +// LastModify: time.Now(), +// }) +// } +// // 处理文件 +// for _, object := range objects { +// rel, err := filepath.Rel(base, object.Key) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object.Key), +// Source: object.Key, +// RelativePath: filepath.ToSlash(rel), +// Size: uint64(object.Size), +// IsDir: false, +// LastModify: object.LastModified, +// }) +// } +// +// return res, nil +//} + +// Get 获取文件 +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + // 凭证有效期 + credentialTTL := handler.settings.UploadSessionTTL(ctx) + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + options := []oss.Option{ + oss.WithContext(ctx), + oss.Expires(time.Now().Add(credentialTTL * time.Second)), + oss.ForbidOverWrite(!overwrite), + oss.ContentType(mimeType), + } + + // 小文件直接上传 + if file.Props.Size < MultiPartUploadThreshold { + return handler.bucket.PutObject(file.Props.SavePath, file, options...) + } + + // 超过阈值时使用分片上传 + imur, err := handler.bucket.InitiateMultipartUpload(file.Props.SavePath, options...) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + parts := make([]oss.UploadPart, 0) + + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{ + Max: handler.settings.ChunkRetryLimit(ctx), + Sleep: chunkRetrySleep, + }, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx)) + + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + part, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1, oss.WithContext(ctx)) + if err == nil { + parts = append(parts, part) + } + return err + } + + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + handler.cancelUpload(imur) + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + _, err = handler.bucket.CompleteMultipartUpload(imur, parts, oss.ForbidOverWrite(!overwrite), oss.WithContext(ctx)) + if err != nil { + handler.cancelUpload(imur) + } + + return err +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + groups := lo.Chunk(files, maxDeleteBatch) + failed := make([]string, 0) + var lastError error + for index, group := range groups { + handler.l.Debug("Process delete group #%d: %v", index, group) + // 删除文件 + delRes, err := handler.bucket.DeleteObjects(group) + if err != nil { + failed = append(failed, group...) + lastError = err + continue + } + + // 统计未删除的文件 + failed = append(failed, util.SliceDifference(files, delRes.DeletedObjects)...) 
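+ // Keys absent from delRes.DeletedObjects are treated as not deleted and appended to failed.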
+ } + + if len(failed) > 0 && lastError == nil { + lastError = fmt.Errorf("failed to delete files: %v", failed) + } + + return failed, lastError +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + usePublicEndpoint := true + if forceUsePublicEndpoint, ok := ctx.Value(driver.ForceUsePublicEndpointCtx{}).(bool); ok { + usePublicEndpoint = forceUsePublicEndpoint + } + + // 初始化客户端 + if err := handler.InitOSSClient(usePublicEndpoint); err != nil { + return "", err + } + + w, h := handler.settings.ThumbSize(ctx) + thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d", h, w) + thumbOption := []oss.Option{oss.Process(thumbParam)} + thumbURL, err := handler.signSourceURL( + ctx, + e.Source(), + expire, + thumbOption, + false, + ) + if err != nil { + return "", err + } + + return thumbURL, nil +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + // 初始化客户端 + usePublicEndpoint := true + if forceUsePublicEndpoint, ok := ctx.Value(driver.ForceUsePublicEndpointCtx{}).(bool); ok { + usePublicEndpoint = forceUsePublicEndpoint + } + if err := handler.InitOSSClient(usePublicEndpoint); err != nil { + return "", err + } + + // 添加各项设置 + var signOptions = make([]oss.Option, 0, 2) + if args.IsDownload { + encodedFilename := url.PathEscape(args.DisplayName) + signOptions = append(signOptions, oss.ResponseContentDisposition(fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + encodedFilename, encodedFilename))) + } + if args.Speed > 0 { + // Byte 转换为 bit + args.Speed *= 8 + + // OSS对速度值有范围限制 + if args.Speed < 819200 { + args.Speed = 819200 + } + if args.Speed > 838860800 { + args.Speed = 838860800 + } + signOptions = append(signOptions, oss.TrafficLimitParam(args.Speed)) + } + + return handler.signSourceURL(ctx, e.Source(), args.Expire, signOptions, false) +} + +func (handler *Driver) signSourceURL(ctx context.Context, path string, expire *time.Time, options []oss.Option, forceSign bool) (string, error) { + ttl := int64(86400 * 365 * 20) + if expire != nil { + ttl = int64(time.Until(*expire).Seconds()) + } + + signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...) 
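+ // With no explicit expire time the URL is signed with the ~20-year TTL computed above (86400*365*20 seconds), effectively a permanent link.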
+ if err != nil { + return "", err + } + + // 将最终生成的签名URL域名换成用户自定义的加速域名(如果有) + finalURL, err := url.Parse(signedURL) + if err != nil { + return "", err + } + + // 公有空间替换掉Key及不支持的头 + if !handler.policy.IsPrivate && !forceSign { + query := finalURL.Query() + query.Del("OSSAccessKeyId") + query.Del("Signature") + query.Del("response-content-disposition") + query.Del("x-oss-traffic-limit") + finalURL.RawQuery = query.Encode() + } + return finalURL.String(), nil +} + +// Token 获取上传策略和认证Token +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // 初始化客户端 + if err := handler.InitOSSClient(true); err != nil { + return nil, err + } + + // 生成回调地址 + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + // 在从机端创建上传会话 + uploadSession.ChunkSize = handler.chunkSize + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeOss, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + // 回调策略 + callbackPolicy := CallbackPolicy{ + CallbackURL: uploadSession.Callback, + CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`, + CallbackBodyType: "application/json", + CallbackSNI: true, + } + callbackPolicyJSON, err := json.Marshal(callbackPolicy) + if err != nil { + return nil, fmt.Errorf("failed to encode callback policy: %w", err) + } + callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON) + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + // 初始化分片上传 + options := []oss.Option{ + oss.WithContext(ctx), + oss.Expires(uploadSession.Props.ExpireAt), + oss.ForbidOverWrite(true), + oss.ContentType(mimeType), + } + imur, err := handler.bucket.InitiateMultipartUpload(file.Props.SavePath, options...) 
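+ // The multipart upload is initiated server-side; its UploadID is stored in the session and embedded in the pre-signed per-part URLs generated below.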
+ if err != nil { + return nil, fmt.Errorf("failed to initialize multipart upload: %w", err) + } + uploadSession.UploadID = imur.UploadID + + // 为每个分片签名上传 URL + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "") + urls := make([]string, chunks.Num()) + ttl := int64(time.Until(uploadSession.Props.ExpireAt).Seconds()) + for chunks.Next() { + err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { + signedURL, err := handler.bucket.SignURL(file.Props.SavePath, oss.HTTPPut, + ttl, + oss.AddParam(partNumberParam, strconv.Itoa(c.Index()+1)), + oss.AddParam(uploadIdParam, imur.UploadID), + oss.ContentType("application/octet-stream")) + if err != nil { + return err + } + + urls[c.Index()] = signedURL + return nil + }) + if err != nil { + return nil, err + } + } + + // 签名完成分片上传的URL + completeURL, err := handler.bucket.SignURL(file.Props.SavePath, oss.HTTPPost, ttl, + oss.ContentType("application/octet-stream"), + oss.AddParam(uploadIdParam, imur.UploadID), + oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)), + oss.SetHeader(completeAllHeader, "yes"), + oss.ForbidOverWrite(true), + oss.AddParam(callbackParam, callbackPolicyEncoded)) + if err != nil { + return nil, err + } + + return &fs.UploadCredential{ + UploadID: imur.UploadID, + UploadURLs: urls, + CompleteURL: completeURL, + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.chunkSize, + }, nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.Props.SavePath}, oss.WithContext(ctx)) +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + mediaMetaExts := handler.policy.Settings.MediaMetaExts + if !handler.policy.Settings.NativeMediaProcessing { + mediaMetaExts = nil + } + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: mediaMetaExts, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: handler.policy.Settings.ThumbExts, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts, + ThumbMaxSize: handler.policy.Settings.ThumbMaxSize, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + if util.ContainsString(supportedImageExt, ext) { + return handler.extractImageMeta(ctx, path) + } + + if util.ContainsString(supportedVideoExt, ext) { + return handler.extractIMMMeta(ctx, path, videoInfoProcess) + } + + if util.ContainsString(supportedAudioExt, ext) { + return handler.extractIMMMeta(ctx, path, audioInfoProcess) + } + + return nil, fmt.Errorf("unsupported media type in oss: %s", ext) +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func (handler *Driver) cancelUpload(imur oss.InitiateMultipartUploadResult) { + if err := handler.bucket.AbortMultipartUpload(imur); err != nil { + handler.l.Warning("failed to abort multipart upload: %s", err) + } +} diff --git a/pkg/filemanager/driver/qiniu/media.go b/pkg/filemanager/driver/qiniu/media.go new file mode 100644 index 00000000..42a0e82e --- /dev/null +++ b/pkg/filemanager/driver/qiniu/media.go @@ -0,0 +1,183 @@ +package qiniu + +import ( + "context" 
+ "encoding/json" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/samber/lo" + "math" + "net/http" + "strconv" + "strings" + "time" +) + +const ( + exifParam = "exif" + avInfoParam = "avinfo" + mediaInfoTTL = time.Duration(10) * time.Minute +) + +var ( + supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff"} +) + +type ( + ImageProp struct { + Value string `json:"val"` + } + ImageInfo map[string]ImageProp + QiniuMediaError struct { + Error string `json:"error"` + Code int `json:"code"` + } +) + +func (handler *Driver) extractAvMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, avInfoParam) + if err != nil { + return nil, err + } + + var avInfo *mediameta.FFProbeMeta + if err := json.Unmarshal([]byte(resp), &avInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + metas := mediameta.ProbeMetaTransform(avInfo) + if artist, ok := avInfo.Format.Tags["artist"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.Artist, + Value: artist, + Type: driver.MediaTypeMusic, + }) + } + + if album, ok := avInfo.Format.Tags["album"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicAlbum, + Value: album, + Type: driver.MediaTypeMusic, + }) + } + + if title, ok := avInfo.Format.Tags["title"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: mediameta.MusicTitle, + Value: title, + Type: driver.MediaTypeMusic, + }) + } + + return metas, nil +} + +func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, exifParam) + if err != nil { + return nil, err + } + + var imageInfo ImageInfo + if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal media info: %w", err) + } + + metas := make([]driver.MediaMeta, 0) + exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) { + return key, value.Value + }) + metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...) + metas = append(metas, parseGpsInfo(imageInfo)...) + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + + return metas, nil +} + +func (handler *Driver) extractMediaInfo(ctx context.Context, path string, param string) (string, error) { + mediaInfoExpire := time.Now().Add(mediaInfoTTL) + ediaInfoUrl := handler.signSourceURL(fmt.Sprintf("%s?%s", path, param), &mediaInfoExpire) + resp, err := handler.httpClient. + Request(http.MethodGet, ediaInfoUrl, nil, request.WithContext(ctx)). + CheckHTTPResponse(http.StatusOK). 
+ GetResponseIgnoreErr() + if err != nil { + return "", unmarshalError(resp, err) + } + + return resp, nil +} + +func unmarshalError(resp string, originErr error) error { + if resp == "" { + return originErr + } + + var err QiniuMediaError + if err := json.Unmarshal([]byte(resp), &err); err != nil { + return fmt.Errorf("failed to unmarshal qiniu error: %w", err) + } + + return fmt.Errorf("qiniu error: %s", err.Error) +} + +func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta { + latitude := imageInfo["GPSLatitude"] // 31, 16.2680820, 0 + longitude := imageInfo["GPSLongitude"] // 120, 42.9103939, 0 + latRef := imageInfo["GPSLatitudeRef"] // N + lonRef := imageInfo["GPSLongitudeRef"] // E + + // Make sure all value exist in map + if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" { + return nil + } + + lat := parseRawGPS(latitude.Value, latRef.Value) + lon := parseRawGPS(longitude.Value, lonRef.Value) + if !math.IsNaN(lat) && !math.IsNaN(lon) { + lat, lng := mediameta.NormalizeGPS(lat, lon) + return []driver.MediaMeta{{ + Key: mediameta.GpsLat, + Value: fmt.Sprintf("%f", lat), + }, { + Key: mediameta.GpsLng, + Value: fmt.Sprintf("%f", lng), + }} + } + + return nil +} + +func parseRawGPS(gpsStr string, ref string) float64 { + elem := strings.Split(gpsStr, ", ") + if len(elem) < 1 { + return 0 + } + + var ( + deg float64 + minutes float64 + seconds float64 + ) + + deg, _ = strconv.ParseFloat(elem[0], 64) + if len(elem) >= 2 { + minutes, _ = strconv.ParseFloat(elem[1], 64) + } + if len(elem) >= 3 { + seconds, _ = strconv.ParseFloat(elem[2], 64) + } + + decimal := deg + minutes/60.0 + seconds/3600.0 + + if ref == "S" || ref == "W" { + return -decimal + } + + return decimal +} diff --git a/pkg/filemanager/driver/qiniu/qiniu.go b/pkg/filemanager/driver/qiniu/qiniu.go new file mode 100644 index 00000000..06cfeeb5 --- /dev/null +++ b/pkg/filemanager/driver/qiniu/qiniu.go @@ -0,0 +1,428 @@ +package qiniu + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/qiniu/go-sdk/v7/auth/qbox" + "github.com/qiniu/go-sdk/v7/storage" + "github.com/samber/lo" + "io" + "net/http" + "net/url" + "os" + "time" +) + +const ( + chunkRetrySleep = time.Duration(5) * time.Second + maxDeleteBatch = 1000 + trafficLimitParam = "X-Qiniu-Traffic-Limit" +) + +var ( + features = &boolset.BooleanSet{} +) + +// Driver 本地策略适配器 +type Driver struct { + policy *ent.StoragePolicy + + mac *qbox.Mac + cfg *storage.Config + bucket *storage.BucketManager + settings setting.Provider + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + httpClient request.Client + + chunkSize int64 +} + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l 
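For reference, the degrees/minutes/seconds arithmetic in `parseRawGPS` works out as follows; the sample coordinates are taken from the inline comments above, and this standalone sketch simply replays the same conversion:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// dmsToDecimal mirrors parseRawGPS above: "deg, minutes, seconds" plus a
// hemisphere reference ("N"/"S"/"E"/"W") becomes signed decimal degrees.
func dmsToDecimal(dms, ref string) float64 {
	parts := strings.Split(dms, ", ")
	val := make([]float64, 3)
	for i := 0; i < len(parts) && i < 3; i++ {
		val[i], _ = strconv.ParseFloat(parts[i], 64)
	}
	decimal := val[0] + val[1]/60.0 + val[2]/3600.0
	if ref == "S" || ref == "W" {
		return -decimal
	}
	return decimal
}

func main() {
	fmt.Printf("%.6f\n", dmsToDecimal("31, 16.2680820, 0", "N"))  // 31.271135
	fmt.Printf("%.6f\n", dmsToDecimal("120, 42.9103939, 0", "E")) // 120.715173
}
```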
logging.Logger, mime mime.MimeDetector) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 25 << 20 // 25 MB + } + + mac := qbox.NewMac(policy.AccessKey, policy.SecretKey) + cfg := &storage.Config{UseHTTPS: true} + + driver := &Driver{ + policy: policy, + settings: settings, + chunkSize: chunkSize, + config: config, + l: l, + mime: mime, + mac: mac, + cfg: cfg, + bucket: storage.NewBucketManager(mac, cfg), + httpClient: request.NewClient(config, request.WithLogger(l)), + } + + return driver, nil +} + +// +//// List 列出给定路径下的文件 +//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// base = strings.TrimPrefix(base, "/") +// if base != "" { +// base += "/" +// } +// +// var ( +// delimiter string +// marker string +// objects []storage.ListItem +// commons []string +// ) +// if !recursive { +// delimiter = "/" +// } +// +// for { +// entries, folders, nextMarker, hashNext, err := handler.bucket.ListFiles( +// handler.policy.BucketName, +// base, delimiter, marker, 1000) +// if err != nil { +// return nil, err +// } +// objects = append(objects, entries...) +// commons = append(commons, folders...) +// if !hashNext { +// break +// } +// marker = nextMarker +// } +// +// // 处理列取结果 +// res := make([]response.Object, 0, len(objects)+len(commons)) +// // 处理目录 +// for _, object := range commons { +// rel, err := filepath.Rel(base, object) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object), +// RelativePath: filepath.ToSlash(rel), +// Size: 0, +// IsDir: true, +// LastModify: time.Now(), +// }) +// } +// // 处理文件 +// for _, object := range objects { +// rel, err := filepath.Rel(base, object.Key) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(object.Key), +// Source: object.Key, +// RelativePath: filepath.ToSlash(rel), +// Size: uint64(object.Fsize), +// IsDir: false, +// LastModify: time.Unix(object.PutTime/10000000, 0), +// }) +// } +// +// return res, nil +//} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + // 凭证有效期 + credentialTTL := handler.settings.UploadSessionTTL(ctx) + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + + // 生成上传策略 + scope := handler.policy.BucketName + if overwrite { + scope = fmt.Sprintf("%s:%s", handler.policy.BucketName, file.Props.SavePath) + } + putPolicy := storage.PutPolicy{ + // 指定为覆盖策略 + Scope: scope, + SaveKey: file.Props.SavePath, + ForceSaveKey: true, + FsizeLimit: file.Props.Size, + Expires: uint64(time.Now().Add(credentialTTL).Unix()), + } + upToken := putPolicy.UploadToken(handler.mac) + + // 初始化分片上传 + resumeUploader := storage.NewResumeUploaderV2(handler.cfg) + upHost, err := resumeUploader.UpHost(handler.policy.AccessKey, handler.policy.BucketName) + if err != nil { + return fmt.Errorf("failed to get upload host: %w", err) + } + + ret := &storage.InitPartsRet{} + err = resumeUploader.InitParts(ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{ + Max: handler.settings.ChunkRetryLimit(ctx), + Sleep: chunkRetrySleep, + }, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx)) + + parts := 
make([]*storage.UploadPartsRet, 0, chunks.Num()) + + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + partRet := &storage.UploadPartsRet{} + err := resumeUploader.UploadParts( + ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret.UploadID, + int64(current.Index()+1), "", partRet, content, int(current.Length())) + if err == nil { + parts = append(parts, partRet) + } + return err + } + + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + _ = handler.cancelUpload(upHost, file.Props.SavePath, ret.UploadID, upToken) + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + err = resumeUploader.CompleteParts(ctx, upToken, upHost, nil, handler.policy.BucketName, + file.Props.SavePath, true, ret.UploadID, &storage.RputV2Extra{ + MimeType: mimeType, + Progresses: lo.Map(parts, func(part *storage.UploadPartsRet, i int) storage.UploadPartInfo { + return storage.UploadPartInfo{ + Etag: part.Etag, + PartNumber: int64(i) + 1, + } + }), + }) + if err != nil { + _ = handler.cancelUpload(upHost, file.Props.SavePath, ret.UploadID, upToken) + } + return nil +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + groups := lo.Chunk(files, maxDeleteBatch) + failed := make([]string, 0) + var lastError error + + for index, group := range groups { + handler.l.Debug("Process delete group #%d: %v", index, group) + // 删除文件 + rets, err := handler.bucket.BatchWithContext(ctx, handler.policy.BucketName, lo.Map(group, func(key string, index int) string { + return storage.URIDelete(handler.policy.BucketName, key) + })) + + // 处理删除结果 + if err != nil { + for k, ret := range rets { + if ret.Code != 200 && ret.Code != 612 { + failed = append(failed, group[k]) + lastError = err + } + } + } + } + + if len(failed) > 0 && lastError == nil { + lastError = fmt.Errorf("failed to delete files: %v", failed) + } + + return failed, lastError +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + w, h := handler.settings.ThumbSize(ctx) + + thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d", e.Source(), w, h) + return handler.signSourceURL( + thumb, + expire, + ), nil +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + path := e.Source() + + query := url.Values{} + + // 加入下载相关设置 + if args.IsDownload { + query.Add("attname", args.DisplayName) + } + + if args.Speed > 0 { + // Byte 转换为 bit + args.Speed *= 8 + + // Qiniu 对速度值有范围限制 + if args.Speed < 819200 { + args.Speed = 819200 + } + if args.Speed > 838860800 { + args.Speed = 838860800 + } + query.Add(trafficLimitParam, fmt.Sprintf("%d", args.Speed)) + } + + if len(query) > 0 { + path = path + "?" 
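The speed handling in `Source` above converts the per-user limit from bytes per second to bits per second and clamps it to the range Qiniu accepts for `X-Qiniu-Traffic-Limit`. Pulled out as a small standalone helper, with example values in the trailing comments (a sketch, not part of the driver):

```go
// trafficLimitBits mirrors the clamping in Source: Qiniu expects
// bits per second within [819200, 838860800].
func trafficLimitBits(bytesPerSecond int64) int64 {
	bits := bytesPerSecond * 8
	if bits < 819200 {
		bits = 819200
	}
	if bits > 838860800 {
		bits = 838860800
	}
	return bits
}

// trafficLimitBits(50 << 10)  -> 819200     (raised to the ~100 KiB/s floor)
// trafficLimitBits(1 << 20)   -> 8388608    (1 MiB/s)
// trafficLimitBits(200 << 20) -> 838860800  (capped at ~100 MiB/s)
```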
+ query.Encode() + } + + // 取得原始文件地址 + return handler.signSourceURL(path, args.Expire), nil +} + +func (handler *Driver) signSourceURL(path string, expire *time.Time) string { + var sourceURL string + if handler.policy.IsPrivate { + deadline := time.Now().Add(time.Duration(24) * time.Hour * 365 * 20).Unix() + if expire != nil { + deadline = expire.Unix() + } + sourceURL = storage.MakePrivateURL(handler.mac, handler.policy.Settings.ProxyServer, path, deadline) + } else { + sourceURL = storage.MakePublicURL(handler.policy.Settings.ProxyServer, path) + } + return sourceURL +} + +// Token 获取上传策略和认证Token +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // 生成回调地址 + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + apiUrl := routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeQiniu, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + // 创建上传策略 + putPolicy := storage.PutPolicy{ + Scope: fmt.Sprintf("%s:%s", handler.policy.BucketName, file.Props.SavePath), + CallbackURL: apiUrl, + CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`, + CallbackBodyType: "application/json", + SaveKey: file.Props.SavePath, + ForceSaveKey: true, + FsizeLimit: file.Props.Size, + Expires: uint64(file.Props.ExpireAt.Unix()), + } + + // 初始化分片上传 + upToken := putPolicy.UploadToken(handler.mac) + resumeUploader := storage.NewResumeUploaderV2(handler.cfg) + upHost, err := resumeUploader.UpHost(handler.policy.AccessKey, handler.policy.BucketName) + if err != nil { + return nil, fmt.Errorf("failed to get upload host: %w", err) + } + + ret := &storage.InitPartsRet{} + err = resumeUploader.InitParts(ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret) + if err != nil { + return nil, fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + uploadSession.UploadID = ret.UploadID + return &fs.UploadCredential{ + UploadID: ret.UploadID, + UploadURLs: []string{getUploadUrl(upHost, handler.policy.BucketName, file.Props.SavePath, ret.UploadID)}, + Credential: upToken, + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.chunkSize, + MimeType: mimeType, + }, nil +} + +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + resumeUploader := storage.NewResumeUploaderV2(handler.cfg) + return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0) +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + mediaMetaExts := handler.policy.Settings.MediaMetaExts + if !handler.policy.Settings.NativeMediaProcessing { + mediaMetaExts = nil + } + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: mediaMetaExts, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: handler.policy.Settings.ThumbExts, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts, + ThumbMaxSize: 
handler.policy.Settings.ThumbMaxSize, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + if util.ContainsString(supportedImageExt, ext) { + return handler.extractImageMeta(ctx, path) + } + + return handler.extractAvMeta(ctx, path) +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func (handler *Driver) cancelUpload(upHost, savePath, uploadId, upToken string) error { + resumeUploader := storage.NewResumeUploaderV2(handler.cfg) + uploadUrl := getUploadUrl(upHost, handler.policy.BucketName, savePath, uploadId) + err := resumeUploader.Client.CallWith(context.Background(), nil, "DELETE", uploadUrl, http.Header{"Authorization": {"UpToken " + upToken}}, nil, 0) + if err != nil { + handler.l.Error("Failed to cancel upload session for %q: %s", savePath, err) + } + return err +} + +func getUploadUrl(upHost, bucket, key, uploadId string) string { + return upHost + "/buckets/" + bucket + "/objects/" + base64.URLEncoding.EncodeToString([]byte(key)) + "/uploads/" + uploadId +} diff --git a/pkg/filemanager/driver/remote/client.go b/pkg/filemanager/driver/remote/client.go new file mode 100644 index 00000000..94016178 --- /dev/null +++ b/pkg/filemanager/driver/remote/client.go @@ -0,0 +1,266 @@ +package remote + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gofrs/uuid" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + OverwriteHeader = constants.CrHeaderPrefix + "Overwrite" + chunkRetrySleep = time.Duration(5) * time.Second +) + +// Client to operate uploading to remote slave server +type Client interface { + // CreateUploadSession creates remote upload session + CreateUploadSession(ctx context.Context, session *fs.UploadSession, overwrite bool) error + // GetUploadURL signs an url for uploading file + GetUploadURL(ctx context.Context, expires time.Time, sessionID string) (string, string, error) + // Upload uploads file to remote server + Upload(ctx context.Context, file *fs.UploadRequest) error + // DeleteUploadSession deletes remote upload session + DeleteUploadSession(ctx context.Context, sessionID string) error + // MediaMeta gets media meta from remote server + MediaMeta(ctx context.Context, src, ext string) ([]driver.MediaMeta, error) + // DeleteFiles deletes files from remote server + DeleteFiles(ctx context.Context, files ...string) ([]string, error) +} + +type DeleteFileRequest struct { + Files []string `json:"files"` +} + +// NewClient creates new Client from given policy +func NewClient(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, config conf.ConfigProvider, l logging.Logger) (Client, error) { + if policy.Edges.Node == nil { + return nil, fmt.Errorf("remote storage policy %d has no node", policy.ID) + } + + authInstance := 
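`getUploadUrl` above follows the resumable-upload v2 object URL layout: the object key is URL-safe base64 encoded and nested under the bucket and upload ID. A quick standalone illustration with made-up host, bucket, key and upload ID:

```go
package main

import (
	"encoding/base64"
	"fmt"
)

func main() {
	// Hypothetical values, assembled the same way as getUploadUrl above.
	upHost := "https://upload-z2.qiniup.com"
	bucket := "my-bucket"
	key := "uploads/2024/photo.jpg"
	uploadID := "abc123"

	u := upHost + "/buckets/" + bucket + "/objects/" +
		base64.URLEncoding.EncodeToString([]byte(key)) + "/uploads/" + uploadID
	fmt.Println(u)
	// https://upload-z2.qiniup.com/buckets/my-bucket/objects/dXBsb2Fkcy8yMDI0L3Bob3RvLmpwZw==/uploads/abc123
}
```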
auth.HMACAuth{[]byte(policy.Edges.Node.SlaveKey)} + serverURL, err := url.Parse(policy.Edges.Node.Server) + if err != nil { + return nil, err + } + + base, _ := url.Parse(constants.APIPrefixSlave) + + return &remoteClient{ + policy: policy, + authInstance: authInstance, + httpClient: request.NewClient( + config, + request.WithEndpoint(serverURL.ResolveReference(base).String()), + request.WithCredential(authInstance, int64(settings.SlaveRequestSignTTL(ctx))), + request.WithSlaveMeta(policy.Edges.Node.ID), + request.WithMasterMeta(settings.SiteBasic(ctx).ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()), + request.WithCorrelationID(), + ), + settings: settings, + l: l, + }, nil +} + +type remoteClient struct { + policy *ent.StoragePolicy + authInstance auth.Auth + httpClient request.Client + settings setting.Provider + l logging.Logger +} + +func (c *remoteClient) Upload(ctx context.Context, file *fs.UploadRequest) error { + ttl := c.settings.UploadSessionTTL(ctx) + session := &fs.UploadSession{ + Props: file.Props.Copy(), + Policy: c.policy, + } + session.Props.UploadSessionID = uuid.Must(uuid.NewV4()).String() + session.Props.ExpireAt = time.Now().Add(ttl) + + // Create upload session + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + if err := c.CreateUploadSession(ctx, session, overwrite); err != nil { + return fmt.Errorf("failed to create upload session: %w", err) + } + + // Initial chunk groups + chunks := chunk.NewChunkGroup(file, c.policy.Settings.ChunkSize, &backoff.ConstantBackoff{ + Max: c.settings.ChunkRetryLimit(ctx), + Sleep: chunkRetrySleep, + }, c.settings.UseChunkBuffer(ctx), c.l, c.settings.TempPath(ctx)) + + uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { + return c.uploadChunk(ctx, session.Props.UploadSessionID, current.Index(), content, overwrite, current.Length()) + } + + // upload chunks + for chunks.Next() { + if err := chunks.Process(uploadFunc); err != nil { + if err := c.DeleteUploadSession(ctx, session.Props.UploadSessionID); err != nil { + c.l.Warning("failed to delete upload session: %s", err) + } + + return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) + } + } + + return nil +} + +func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error { + resp, err := c.httpClient.Request( + "DELETE", + "upload/"+sessionID, + nil, + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (c *remoteClient) DeleteFiles(ctx context.Context, files ...string) ([]string, error) { + req := &DeleteFileRequest{ + Files: files, + } + + reqStr, err := json.Marshal(req) + if err != nil { + return files, fmt.Errorf("failed to marshal delete request: %w", err) + } + + resp, err := c.httpClient.Request( + "DELETE", + "file", + bytes.NewReader(reqStr), + request.WithContext(ctx), + request.WithLogger(logging.FromContext(ctx)), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return files, err + } + + if resp.Code != 0 { + var failed []string + failed = files + if resp.Code == serializer.CodeNotFullySuccess { + resp.GobDecode(&failed) + } + return failed, fmt.Errorf(resp.Error) + } + + return nil, nil +} + +func (c *remoteClient) MediaMeta(ctx context.Context, src, ext string) ([]driver.MediaMeta, error) { + resp, err := c.httpClient.Request( + http.MethodGet, + 
routes.SlaveMediaMetaRoute(src, ext), + nil, + request.WithContext(ctx), + request.WithLogger(c.l), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return nil, err + } + + if resp.Code != 0 { + return nil, fmt.Errorf(resp.Error) + } + + var metas []driver.MediaMeta + resp.GobDecode(&metas) + return metas, nil +} + +func (c *remoteClient) CreateUploadSession(ctx context.Context, session *fs.UploadSession, overwrite bool) error { + reqBodyEncoded, err := json.Marshal(map[string]interface{}{ + "session": session, + "overwrite": overwrite, + }) + if err != nil { + return err + } + + bodyReader := strings.NewReader(string(reqBodyEncoded)) + resp, err := c.httpClient.Request( + "PUT", + "upload", + bodyReader, + request.WithContext(ctx), + request.WithLogger(c.l), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} + +func (c *remoteClient) GetUploadURL(ctx context.Context, expires time.Time, sessionID string) (string, string, error) { + base, err := url.Parse(c.policy.Edges.Node.Server) + if err != nil { + return "", "", err + } + + req, err := http.NewRequest("POST", routes.SlaveUploadUrl(base, sessionID).String(), nil) + if err != nil { + return "", "", err + } + + req = auth.SignRequest(ctx, c.authInstance, req, &expires) + return req.URL.String(), req.Header["Authorization"][0], nil +} + +func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error { + resp, err := c.httpClient.Request( + "POST", + fmt.Sprintf("upload/%s?chunk=%d", sessionID, index), + chunk, + request.WithContext(ctx), + request.WithTimeout(time.Duration(0)), + request.WithContentLength(size), + request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}), + ).CheckHTTPResponse(200).DecodeResponse() + if err != nil { + return err + } + + if resp.Code != 0 { + return serializer.NewErrorFromResponse(resp) + } + + return nil +} diff --git a/pkg/filemanager/driver/remote/remote.go b/pkg/filemanager/driver/remote/remote.go new file mode 100644 index 00000000..9278655e --- /dev/null +++ b/pkg/filemanager/driver/remote/remote.go @@ -0,0 +1,273 @@ +package remote + +import ( + "context" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "net/url" + "os" + "path" + "time" +) + +var ( + features = &boolset.BooleanSet{} +) + +// Driver 远程存储策略适配器 +type Driver struct { + Client request.Client + Policy *ent.StoragePolicy + AuthInstance auth.Auth + + uploadClient Client + config conf.ConfigProvider + settings setting.Provider +} + +// New initializes a new Driver from policy +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger) (*Driver, error) { + client, err := NewClient(ctx, policy, settings, config, l) + if err != nil { + return nil, err + } + + return &Driver{ + Policy: policy, + Client: 
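Each chunk reaches the slave as a plain POST against the upload session, as `uploadChunk` above shows. Stripped of the project's `request.Client` wrapper, one such request looks roughly like the sketch below; the slave base URL, the overwrite header name, and authentication are placeholders (in the real client they come from `constants.APIPrefixSlave`, `OverwriteHeader`, and the HMAC credential middleware), and the imports assumed are `context`, `fmt`, `io`, `net/http`, `strconv`:

```go
// Rough shape of one chunk upload; names marked "placeholder" are not the
// real constants used by the driver.
func postChunk(ctx context.Context, slaveBase, sessionID string, index int, body io.Reader, size int64, overwrite bool) error {
	endpoint := fmt.Sprintf("%s/upload/%s?chunk=%d", slaveBase, sessionID, index)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
	if err != nil {
		return err
	}
	req.ContentLength = size
	req.Header.Set("X-Overwrite-Placeholder", strconv.FormatBool(overwrite))
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("chunk %d rejected: %s", index, resp.Status)
	}
	return nil
}
```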
request.NewClient(config), + AuthInstance: auth.HMACAuth{[]byte(policy.Edges.Node.SlaveKey)}, + uploadClient: client, + settings: settings, + config: config, + }, nil +} + +//// List 列取文件 +//func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { +// var res []response.Object +// +// reqBody := serializer.ListRequest{ +// Path: path, +// Recursive: recursive, +// } +// reqBodyEncoded, err := json.Marshal(reqBody) +// if err != nil { +// return res, err +// } +// +// // 发送列表请求 +// bodyReader := strings.NewReader(string(reqBodyEncoded)) +// signTTL := model.GetIntSetting("slave_api_timeout", 60) +// resp, err := handler.Client.Request( +// "POST", +// handler.getAPIUrl("list"), +// bodyReader, +// request.WithCredential(handler.AuthInstance, int64(signTTL)), +// request.WithMasterMeta(handler.settings.SiteBasic(ctx).ID, handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()), +// ).CheckHTTPResponse(200).DecodeResponse() +// if err != nil { +// return res, err +// } +// +// // 处理列取结果 +// if resp.Code != 0 { +// return res, errors.New(resp.Error) +// } +// +// if resStr, ok := resp.Data.(string); ok { +// err = json.Unmarshal([]byte(resStr), &res) +// if err != nil { +// return res, err +// } +// } +// +// return res, nil +//} + +// getAPIUrl 获取接口请求地址 +func (handler *Driver) getAPIUrl(scope string, routes ...string) string { + serverURL, err := url.Parse(handler.Policy.Edges.Node.Server) + if err != nil { + return "" + } + var controller *url.URL + + switch scope { + case "delete": + controller, _ = url.Parse("/api/v3/slave/delete") + case "thumb": + controller, _ = url.Parse("/api/v3/slave/thumb") + case "list": + controller, _ = url.Parse("/api/v3/slave/list") + default: + controller = serverURL + } + + for _, r := range routes { + controller.Path = path.Join(controller.Path, r) + } + + return serverURL.ResolveReference(controller).String() +} + +// Open 获取文件内容 +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + //// 尝试获取速度限制 + //speedLimit := 0 + //if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { + // speedLimit = user.Group.SpeedLimit + //} + // + //// 获取文件源地址 + //downloadURL, err := handler.Source(ctx, path, nil, true, int64(speedLimit)) + //if err != nil { + // return nil, err + //} + // + //// 获取文件数据流 + //resp, err := handler.Client.Request( + // "GET", + // downloadURL, + // nil, + // request.WithContext(ctx), + // request.WithTimeout(time.Duration(0)), + // request.WithMasterMeta(handler.settings.SiteBasic(ctx).ID, handler.settings.SiteURL(ctx).String()), + //).CheckHTTPResponse(200).GetRSCloser() + //if err != nil { + // return nil, err + //} + // + //resp.SetFirstFakeChunk() + // + //// 尝试获取文件大小 + //if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { + // resp.SetContentLength(int64(file.Size)) + //} + + return nil, errors.New("not implemented") +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + return handler.uploadClient.Upload(ctx, file) +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + failed, err := handler.uploadClient.DeleteFiles(ctx, files...) 
+ if err != nil { + return failed, err + } + return []string{}, nil +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + serverURL, err := url.Parse(handler.Policy.Edges.Node.Server) + if err != nil { + return "", fmt.Errorf("parse server url failed: %w", err) + } + + thumbURL := routes.SlaveThumbUrl(serverURL, e.Source(), ext) + signedThumbURL, err := auth.SignURI(ctx, handler.AuthInstance, thumbURL.String(), expire) + if err != nil { + return "", err + } + + return signedThumbURL.String(), nil +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + server, err := url.Parse(handler.Policy.Edges.Node.Server) + if err != nil { + return "", err + } + + nodeId := 0 + if handler.config.System().Mode == conf.SlaveMode { + nodeId = handler.Policy.NodeID + } + + base := routes.SlaveFileContentUrl( + server, + e.Source(), + args.DisplayName, + args.IsDownload, + args.Speed, + nodeId, + ) + internalProxyed, err := auth.SignURI(ctx, handler.AuthInstance, base.String(), args.Expire) + if err != nil { + return "", fmt.Errorf("failed to sign internal slave content URL: %w", err) + } + + return internalProxyed.String(), nil +} + +// Token 获取上传策略和认证Token +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + // 在从机端创建上传会话 + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeRemote, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, false); err != nil { + return nil, err + } + + // 获取上传地址 + uploadURL, sign, err := handler.uploadClient.GetUploadURL(ctx, uploadSession.Props.ExpireAt, uploadSession.Props.UploadSessionID) + if err != nil { + return nil, fmt.Errorf("failed to sign upload url: %w", err) + } + + return &fs.UploadCredential{ + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.Policy.Settings.ChunkSize, + UploadURLs: []string{uploadURL}, + Credential: sign, + }, nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Props.UploadSessionID) +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: handler.Policy.Settings.MediaMetaExts, + MediaMetaProxy: handler.Policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: handler.Policy.Settings.ThumbExts, + ThumbProxy: handler.Policy.Settings.ThumbGeneratorProxy, + ThumbMaxSize: handler.Policy.Settings.ThumbMaxSize, + ThumbSupportAllExts: handler.Policy.Settings.ThumbSupportAllExts, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + return handler.uploadClient.MediaMeta(ctx, path, ext) +} diff --git a/pkg/filemanager/driver/s3/s3.go b/pkg/filemanager/driver/s3/s3.go new file mode 100644 index 00000000..bbfb5510 --- /dev/null +++ b/pkg/filemanager/driver/s3/s3.go @@ -0,0 +1,514 @@ +package s3 + +import ( + "context" + "errors" + "fmt" + "io" + "net/url" + "os" + "time" + + 
"github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" +) + +// Driver S3 compatible driver +type Driver struct { + policy *ent.StoragePolicy + chunkSize int64 + + settings setting.Provider + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + + sess *session.Session + svc *s3.S3 +} + +// UploadPolicy S3上传策略 +type UploadPolicy struct { + Expiration string `json:"expiration"` + Conditions []interface{} `json:"conditions"` +} + +// MetaData 文件信息 +type MetaData struct { + Size int64 + Etag string +} + +var ( + features = &boolset.BooleanSet{} +) + +func init() { + boolset.Sets(map[driver.HandlerCapability]bool{ + driver.HandlerCapabilityUploadSentinelRequired: true, + }, features) +} + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) { + chunkSize := policy.Settings.ChunkSize + if policy.Settings.ChunkSize == 0 { + chunkSize = 25 << 20 // 25 MB + } + + driver := &Driver{ + policy: policy, + settings: settings, + chunkSize: chunkSize, + config: config, + l: l, + mime: mime, + } + + sess, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(policy.AccessKey, policy.SecretKey, ""), + Endpoint: &policy.Server, + Region: &policy.Settings.Region, + S3ForcePathStyle: &policy.Settings.S3ForcePathStyle, + }) + + if err != nil { + return nil, err + } + driver.sess = sess + driver.svc = s3.New(sess) + + return driver, nil +} + +//// List 列出给定路径下的文件 +//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// // 初始化列目录参数 +// base = strings.TrimPrefix(base, "/") +// if base != "" { +// base += "/" +// } +// +// opt := &s3.ListObjectsInput{ +// Bucket: &handler.policy.BucketName, +// Prefix: &base, +// MaxKeys: aws.Int64(1000), +// } +// +// // 是否为递归列出 +// if !recursive { +// opt.Delimiter = aws.String("/") +// } +// +// var ( +// objects []*s3.Object +// commons []*s3.CommonPrefix +// ) +// +// for { +// res, err := handler.svc.ListObjectsWithContext(ctx, opt) +// if err != nil { +// return nil, err +// } +// objects = append(objects, res.Contents...) +// commons = append(commons, res.CommonPrefixes...) 
+// +// // 如果本次未列取完,则继续使用marker获取结果 +// if *res.IsTruncated { +// opt.Marker = res.NextMarker +// } else { +// break +// } +// } +// +// // 处理列取结果 +// res := make([]response.Object, 0, len(objects)+len(commons)) +// +// // 处理目录 +// for _, object := range commons { +// rel, err := filepath.Rel(*opt.Prefix, *object.Prefix) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(*object.Prefix), +// RelativePath: filepath.ToSlash(rel), +// Size: 0, +// IsDir: true, +// LastModify: time.Now(), +// }) +// } +// // 处理文件 +// for _, object := range objects { +// rel, err := filepath.Rel(*opt.Prefix, *object.Key) +// if err != nil { +// continue +// } +// res = append(res, response.Object{ +// Name: path.Base(*object.Key), +// Source: *object.Key, +// RelativePath: filepath.ToSlash(rel), +// Size: uint64(*object.Size), +// IsDir: false, +// LastModify: time.Now(), +// }) +// } +// +// return res, nil +// +//} + +// Open 打开文件 +func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + if !overwrite { + // Check for duplicated file + if _, err := handler.Meta(ctx, file.Props.SavePath); err == nil { + return fs.ErrFileExisted + } + } + + uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) { + u.PartSize = handler.chunkSize + }) + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + _, err := uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: &handler.policy.BucketName, + Key: &file.Props.SavePath, + Body: io.LimitReader(file, file.Props.Size), + ContentType: aws.String(mimeType), + }) + + if err != nil { + return err + } + + return nil +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + failed := make([]string, 0, len(files)) + batchSize := handler.policy.Settings.S3DeleteBatchSize + if batchSize == 0 { + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html + // The request can contain a list of up to 1000 keys that you want to delete. + batchSize = 1000 + } + + var lastErr error + + groups := lo.Chunk(files, batchSize) + for _, group := range groups { + if len(group) == 1 { + // Invoke single file delete API + _, err := handler.svc.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ + Bucket: &handler.policy.BucketName, + Key: &group[0], + }) + + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + // Ignore NoSuchKey error + if aerr.Code() == s3.ErrCodeNoSuchKey { + continue + } + } + failed = append(failed, group[0]) + lastErr = err + } + } else { + // Invoke batch delete API + res, err := handler.svc.DeleteObjects( + &s3.DeleteObjectsInput{ + Bucket: &handler.policy.BucketName, + Delete: &s3.Delete{ + Objects: lo.Map(group, func(s string, i int) *s3.ObjectIdentifier { + return &s3.ObjectIdentifier{Key: &s} + }), + }, + }) + + if err != nil { + failed = append(failed, group...) 
+ lastErr = err + continue + } + + for _, v := range res.Errors { + handler.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key) + failed = append(failed, *v.Key) + } + } + } + + return failed, lastErr + +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + return "", errors.New("not implemented") +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + var contentDescription *string + if args.IsDownload { + encodedFilename := url.PathEscape(args.DisplayName) + contentDescription = aws.String(fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + encodedFilename, encodedFilename)) + } + + req, _ := handler.svc.GetObjectRequest( + &s3.GetObjectInput{ + Bucket: &handler.policy.BucketName, + Key: aws.String(e.Source()), + ResponseContentDisposition: contentDescription, + }) + + ttl := time.Duration(604800) * time.Second // 7 days + if args.Expire != nil { + ttl = time.Until(*args.Expire) + } + signedURL, err := req.Presign(ttl) + if err != nil { + return "", err + } + + // 将最终生成的签名URL域名换成用户自定义的加速域名(如果有) + finalURL, err := url.Parse(signedURL) + if err != nil { + return "", err + } + + // 公有空间替换掉Key及不支持的头 + if !handler.policy.IsPrivate { + finalURL.RawQuery = "" + } + + return finalURL.String(), nil +} + +// Token 获取上传策略和认证Token +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + // Check for duplicated file + if _, err := handler.Meta(ctx, file.Props.SavePath); err == nil { + return nil, fs.ErrFileExisted + } + + // 生成回调地址 + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + // 在从机端创建上传会话 + uploadSession.ChunkSize = handler.chunkSize + uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeS3, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + // 创建分片上传 + res, err := handler.svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{ + Bucket: &handler.policy.BucketName, + Key: &uploadSession.Props.SavePath, + Expires: &uploadSession.Props.ExpireAt, + ContentType: aws.String(mimeType), + }) + if err != nil { + return nil, fmt.Errorf("failed to create multipart upload: %w", err) + } + + uploadSession.UploadID = *res.UploadId + + // 为每个分片签名上传 URL + chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "") + urls := make([]string, chunks.Num()) + for chunks.Next() { + err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { + signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{ + Bucket: &handler.policy.BucketName, + Key: &uploadSession.Props.SavePath, + PartNumber: aws.Int64(int64(c.Index() + 1)), + ContentLength: aws.Int64(c.Length()), + UploadId: res.UploadId, + }) + + signedURL, err := signedReq.Presign(time.Until(uploadSession.Props.ExpireAt)) + if err != nil { + return err + } + + urls[c.Index()] = signedURL + return nil + }) + if err != nil { + return nil, err + } + } + + // 签名完成分片上传的请求URL + signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{ + Bucket: &handler.policy.BucketName, + Key: &file.Props.SavePath, + UploadId: res.UploadId, + }) + + signedURL, err := 
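The Content-Disposition built for S3 downloads above encodes the display name twice, once in the quoted `filename` parameter and once as an RFC 5987 `filename*` parameter, so non-ASCII names survive. A tiny standalone illustration (the file name is made up):

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	name := "报表 2024.pdf" // a display name containing non-ASCII characters
	encoded := url.PathEscape(name)
	cd := fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, encoded, encoded)
	fmt.Println(cd)
	// attachment; filename="%E6%8A%A5%E8%A1%A8%202024.pdf"; filename*=UTF-8''%E6%8A%A5%E8%A1%A8%202024.pdf
}
```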
signedReq.Presign(time.Until(uploadSession.Props.ExpireAt)) + if err != nil { + return nil, err + } + + // 生成上传凭证 + return &fs.UploadCredential{ + UploadID: *res.UploadId, + UploadURLs: urls, + CompleteURL: signedURL, + SessionID: uploadSession.Props.UploadSessionID, + ChunkSize: handler.chunkSize, + }, nil +} + +// Meta 获取文件信息 +func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) { + res, err := handler.svc.HeadObjectWithContext(ctx, + &s3.HeadObjectInput{ + Bucket: &handler.policy.BucketName, + Key: &path, + }) + + if err != nil { + return nil, err + } + + return &MetaData{ + Size: *res.ContentLength, + Etag: *res.ETag, + }, nil + +} + +// CORS 创建跨域策略 +func (handler *Driver) CORS() error { + rule := s3.CORSRule{ + AllowedMethods: aws.StringSlice([]string{ + "GET", + "POST", + "PUT", + "DELETE", + "HEAD", + }), + AllowedOrigins: aws.StringSlice([]string{"*"}), + AllowedHeaders: aws.StringSlice([]string{"*"}), + ExposeHeaders: aws.StringSlice([]string{"ETag"}), + MaxAgeSeconds: aws.Int64(3600), + } + + _, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{ + Bucket: &handler.policy.BucketName, + CORSConfiguration: &s3.CORSConfiguration{ + CORSRules: []*s3.CORSRule{&rule}, + }, + }) + + return err +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + _, err := handler.svc.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{ + UploadId: &uploadSession.UploadID, + Bucket: &handler.policy.BucketName, + Key: &uploadSession.Props.SavePath, + }) + return err +} + +func (handler *Driver) cancelUpload(key, id *string) { + if _, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + Bucket: &handler.policy.BucketName, + UploadId: id, + Key: key, + }); err != nil { + handler.l.Warning("failed to abort multipart upload: %s", err) + } +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + MaxSourceExpire: time.Duration(604800) * time.Second, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + return nil, errors.New("not implemented") +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + if session.SentinelTaskID == 0 { + return nil + } + + // Make sure uploaded file size is correct + res, err := handler.Meta(ctx, session.Props.SavePath) + if err != nil { + return fmt.Errorf("failed to get uploaded file size: %w", err) + } + + if res.Size != session.Props.Size { + return serializer.NewError( + serializer.CodeMetaMismatch, + fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.Size), + nil, + ) + } + return nil +} + +type Reader struct { + r io.Reader +} + +func (r Reader) Read(p []byte) (int, error) { + return r.r.Read(p) +} diff --git a/pkg/filemanager/driver/upyun/media.go b/pkg/filemanager/driver/upyun/media.go new file mode 100644 index 00000000..bf36acb1 --- /dev/null +++ b/pkg/filemanager/driver/upyun/media.go @@ -0,0 +1,154 @@ +package upyun + +import ( + "context" + "encoding/json" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/mediameta" + 
"github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/samber/lo" + "math" + "net/http" + "strconv" + "strings" + "time" +) + +var ( + mediaInfoTTL = time.Duration(10) * time.Minute +) + +type ( + ImageInfo struct { + Exif map[string]string `json:"EXIF"` + } +) + +func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) { + resp, err := handler.extractMediaInfo(ctx, path, "!/meta") + if err != nil { + return nil, err + } + + fmt.Println(resp) + + var imageInfo ImageInfo + if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal image info: %w", err) + } + + metas := make([]driver.MediaMeta, 0, len(imageInfo.Exif)) + exifMap := lo.MapEntries(imageInfo.Exif, func(key string, value string) (string, string) { + switch key { + case "0xA434": + key = "LensModel" + } + return key, value + }) + metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...) + metas = append(metas, parseGpsInfo(imageInfo.Exif)...) + + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + + return metas, nil +} + +func (handler *Driver) extractMediaInfo(ctx context.Context, path string, param string) (string, error) { + mediaInfoExpire := time.Now().Add(mediaInfoTTL) + mediaInfoUrl, err := handler.signURL(ctx, path+param, nil, &mediaInfoExpire) + if err != nil { + return "", err + } + + resp, err := handler.httpClient. + Request(http.MethodGet, mediaInfoUrl, nil, request.WithContext(ctx)). + CheckHTTPResponse(http.StatusOK). + GetResponseIgnoreErr() + if err != nil { + return "", unmarshalError(resp, err) + } + + return resp, nil +} + +func unmarshalError(resp string, err error) error { + return fmt.Errorf("upyun error: %s", err) +} + +func parseGpsInfo(imageInfo map[string]string) []driver.MediaMeta { + latitude := imageInfo["GPSLatitude"] // 31/1, 162680820/10000000, 0/1 + longitude := imageInfo["GPSLongitude"] // 120/1, 429103939/10000000, 0/1 + latRef := imageInfo["GPSLatitudeRef"] // N + lonRef := imageInfo["GPSLongitudeRef"] // E + + // Make sure all value exist in map + if latitude == "" || longitude == "" || latRef == "" || lonRef == "" { + return nil + } + + lat := parseRawGPS(latitude, latRef) + lon := parseRawGPS(longitude, lonRef) + if !math.IsNaN(lat) && !math.IsNaN(lon) { + lat, lng := mediameta.NormalizeGPS(lat, lon) + return []driver.MediaMeta{{ + Key: mediameta.GpsLat, + Value: fmt.Sprintf("%f", lat), + }, { + Key: mediameta.GpsLng, + Value: fmt.Sprintf("%f", lng), + }} + } + + return nil +} + +func parseRawGPS(gpsStr string, ref string) float64 { + elem := strings.Split(gpsStr, ",") + if len(elem) < 1 { + return 0 + } + + var ( + deg float64 + minutes float64 + seconds float64 + ) + + deg = getGpsElemValue(elem[0]) + if len(elem) >= 2 { + minutes = getGpsElemValue(elem[1]) + } + if len(elem) >= 3 { + seconds = getGpsElemValue(elem[2]) + } + + decimal := deg + minutes/60.0 + seconds/3600.0 + + if ref == "S" || ref == "W" { + return -decimal + } + + return decimal +} + +func getGpsElemValue(elm string) float64 { + elements := strings.Split(elm, "/") + if len(elements) != 2 { + return 0 + } + + numerator, err := strconv.ParseFloat(elements[0], 64) + if err != nil { + return 0 + } + + denominator, err := strconv.ParseFloat(elements[1], 64) + if err != nil || denominator == 0 { + return 0 + } + + return numerator / denominator +} diff --git a/pkg/filemanager/driver/upyun/upyun.go b/pkg/filemanager/driver/upyun/upyun.go new file mode 100644 index 00000000..2b055cc5 
--- /dev/null +++ b/pkg/filemanager/driver/upyun/upyun.go @@ -0,0 +1,382 @@ +package upyun + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gin-gonic/gin" + "github.com/upyun/go-sdk/upyun" + "io" + "net/url" + "os" + "strconv" + "strings" + "time" +) + +type ( + // UploadPolicy 又拍云上传策略 + UploadPolicy struct { + Bucket string `json:"bucket"` + SaveKey string `json:"save-key"` + Expiration int64 `json:"expiration"` + CallbackURL string `json:"notify-url"` + ContentLength uint64 `json:"content-length"` + ContentLengthRange string `json:"content-length-range,omitempty"` + AllowFileType string `json:"allow-file-type,omitempty"` + } + // Driver 又拍云策略适配器 + Driver struct { + policy *ent.StoragePolicy + + up *upyun.UpYun + settings setting.Provider + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + httpClient request.Client + } +) + +var ( + features = &boolset.BooleanSet{} +) + +func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, + config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) { + driver := &Driver{ + policy: policy, + settings: settings, + config: config, + l: l, + mime: mime, + httpClient: request.NewClient(config, request.WithLogger(l)), + up: upyun.NewUpYun(&upyun.UpYunConfig{ + Bucket: policy.BucketName, + Operator: policy.AccessKey, + Password: policy.SecretKey, + }), + } + + return driver, nil +} + +//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { +// base = strings.TrimPrefix(base, "/") +// +// // 用于接受SDK返回对象的chan +// objChan := make(chan *upyun.FileInfo) +// objects := []*upyun.FileInfo{} +// +// // 列取配置 +// listConf := &upyun.GetObjectsConfig{ +// Path: "/" + base, +// ObjectsChan: objChan, +// MaxListTries: 1, +// } +// // 递归列取时不限制递归次数 +// if recursive { +// listConf.MaxListLevel = -1 +// } +// +// // 启动一个goroutine收集列取的对象信 +// wg := &sync.WaitGroup{} +// wg.Add(1) +// go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) { +// defer wg.Done() +// for { +// file, ok := <-input +// if !ok { +// return +// } +// *output = append(*output, file) +// } +// }(objChan, &objects, wg) +// +// up := upyun.NewUpYun(&upyun.UpYunConfig{ +// Bucket: handler.policy.BucketName, +// Operator: handler.policy.AccessKey, +// Password: handler.policy.SecretKey, +// }) +// +// err := up.List(listConf) +// if err != nil { +// return nil, err +// } +// +// wg.Wait() +// +// // 汇总处理列取结果 +// res := make([]response.Object, 0, len(objects)) +// for _, object := range objects { +// res = append(res, response.Object{ +// Name: path.Base(object.Name), +// RelativePath: object.Name, +// Source: path.Join(base, object.Name), +// Size: uint64(object.Size), +// IsDir: object.IsDir, +// LastModify: object.Time, +// }) +// } +// +// return res, nil +//} + +func (handler 
*Driver) Open(ctx context.Context, path string) (*os.File, error) { + return nil, errors.New("not implemented") +} + +// Put 将文件流保存到指定目录 +func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error { + defer file.Close() + + // 是否允许覆盖 + overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite + if !overwrite { + if _, err := handler.up.GetInfo(file.Props.SavePath); err == nil { + return fs.ErrFileExisted + } + } + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + err := handler.up.Put(&upyun.PutObjectConfig{ + Path: file.Props.SavePath, + Reader: file, + Headers: map[string]string{ + "Content-Type": mimeType, + }, + }) + + return err +} + +// Delete 删除一个或多个文件, +// 返回未删除的文件,及遇到的最后一个错误 +func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) { + failed := make([]string, 0) + var lastErr error + + for _, file := range files { + if err := handler.up.Delete(&upyun.DeleteObjectConfig{ + Path: file, + Async: true, + }); err != nil { + filteredErr := strings.ReplaceAll(err.Error(), file, "") + if strings.Contains(filteredErr, "Not found") || + strings.Contains(filteredErr, "NoSuchKey") { + continue + } + + failed = append(failed, file) + lastErr = err + } + } + + return failed, lastErr +} + +// Thumb 获取文件缩略图 +func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) { + w, h := handler.settings.ThumbSize(ctx) + + thumbParam := fmt.Sprintf("!/fwfh/%dx%d", w, h) + thumbURL, err := handler.signURL(ctx, e.Source()+thumbParam, nil, expire) + if err != nil { + return "", err + } + + return thumbURL, nil +} + +// Source 获取外链URL +func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) { + query := url.Values{} + + // 如果是下载文件URL + if args.IsDownload { + query.Add("_upd", args.DisplayName) + } + + return handler.signURL(ctx, e.Source(), &query, args.Expire) +} + +func (handler *Driver) signURL(ctx context.Context, path string, query *url.Values, expire *time.Time) (string, error) { + sourceURL, err := url.Parse(handler.policy.Settings.ProxyServer) + if err != nil { + return "", err + } + + fileKey, err := url.Parse(url.PathEscape(path)) + if err != nil { + return "", err + } + + sourceURL = sourceURL.ResolveReference(fileKey) + if query != nil { + sourceURL.RawQuery = query.Encode() + + } + + if !handler.policy.IsPrivate { + // 未开启Token防盗链时,直接返回 + return sourceURL.String(), nil + } + + etime := time.Now().Add(time.Duration(24) * time.Hour * 365 * 20).Unix() + if expire != nil { + etime = expire.Unix() + } + signStr := fmt.Sprintf( + "%s&%d&%s", + handler.policy.Settings.Token, + etime, + sourceURL.Path, + ) + signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr))) + finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10) + + // 将签名添加到URL中 + q := sourceURL.Query() + q.Add("_upt", finalSign) + sourceURL.RawQuery = q.Encode() + + return sourceURL.String(), nil +} + +// Token 获取上传策略和认证Token +func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) { + if _, err := handler.up.GetInfo(file.Props.SavePath); err == nil { + return nil, fs.ErrFileExisted + } + + // 生成回调地址 + siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)) + apiUrl := routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeUpyun, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String() + + // 上传策略 + putPolicy := 
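With token anti-leech enabled, the `_upt` value is the middle slice of an MD5 over `secret&expiry&path`, followed by the expiry timestamp, as computed in `signURL` above. A standalone sketch; the secret, path and timestamp are made up:

```go
package main

import (
	"crypto/md5"
	"fmt"
	"strconv"
)

// upt reproduces the token calculation in signURL: md5("<secret>&<etime>&<path>"),
// characters 12..19 of the hex digest, then the expiry timestamp appended.
func upt(secret, path string, etime int64) string {
	sum := fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s&%d&%s", secret, etime, path))))
	return sum[12:20] + strconv.FormatInt(etime, 10)
}

func main() {
	// Hypothetical values; the result would be appended as "?_upt=..." to the source URL.
	fmt.Println(upt("my-secret", "/uploads/photo.jpg", 1735689600))
}
```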
UploadPolicy{ + Bucket: handler.policy.BucketName, + SaveKey: file.Props.SavePath, + Expiration: uploadSession.Props.ExpireAt.Unix(), + CallbackURL: apiUrl, + ContentLength: uint64(file.Props.Size), + ContentLengthRange: fmt.Sprintf("0,%d", file.Props.Size), + } + + // 生成上传凭证 + policyJSON, err := json.Marshal(putPolicy) + if err != nil { + return nil, err + } + policyEncoded := base64.StdEncoding.EncodeToString(policyJSON) + + // 生成签名 + elements := []string{"POST", "/" + handler.policy.BucketName, policyEncoded} + signStr := sign(handler.policy.AccessKey, handler.policy.SecretKey, elements) + + mimeType := file.Props.MimeType + if mimeType == "" { + handler.mime.TypeByName(file.Props.Uri.Name()) + } + + return &fs.UploadCredential{ + UploadPolicy: policyEncoded, + UploadURLs: []string{"https://v0.api.upyun.com/" + handler.policy.BucketName}, + Credential: signStr, + MimeType: mimeType, + }, nil +} + +// 取消上传凭证 +func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error { + return nil +} + +func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error { + return nil +} + +func (handler *Driver) Capabilities() *driver.Capabilities { + mediaMetaExts := handler.policy.Settings.MediaMetaExts + if !handler.policy.Settings.NativeMediaProcessing { + mediaMetaExts = nil + } + return &driver.Capabilities{ + StaticFeatures: features, + MediaMetaSupportedExts: mediaMetaExts, + MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy, + ThumbSupportedExts: handler.policy.Settings.ThumbExts, + ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy, + ThumbMaxSize: handler.policy.Settings.ThumbMaxSize, + ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts, + } +} + +func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) { + return handler.extractImageMeta(ctx, path) +} + +func (handler *Driver) LocalPath(ctx context.Context, path string) string { + return "" +} + +func ValidateCallback(c *gin.Context, session *fs.UploadSession) error { + body, err := io.ReadAll(c.Request.Body) + c.Request.Body.Close() + if err != nil { + return fmt.Errorf("failed to read request body: %w", err) + } + + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + contentMD5 := c.Request.Header.Get("Content-Md5") + date := c.Request.Header.Get("Date") + actualSignature := c.Request.Header.Get("Authorization") + actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body)) + if actualContentMD5 != contentMD5 { + return errors.New("MD5 mismatch") + } + + // Compare signature + signature := sign(session.Policy.AccessKey, session.Policy.SecretKey, []string{ + "POST", + c.Request.URL.Path, + date, + contentMD5, + }) + if signature != actualSignature { + return errors.New("Signature not match") + } + + return nil +} + +// Sign 计算又拍云的签名头 +func sign(ak, sk string, elements []string) string { + password := fmt.Sprintf("%x", md5.Sum([]byte(sk))) + mac := hmac.New(sha1.New, []byte(password)) + value := strings.Join(elements, "&") + mac.Write([]byte(value)) + signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil))) + return fmt.Sprintf("UPYUN %s:%s", ak, signStr) +} diff --git a/pkg/filemanager/driver/util.go b/pkg/filemanager/driver/util.go new file mode 100644 index 00000000..14418f5a --- /dev/null +++ b/pkg/filemanager/driver/util.go @@ -0,0 +1,37 @@ +package driver + +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "net/url" + "path" + "strings" +) + +func ApplyProxyIfNeeded(policy 
*ent.StoragePolicy, srcUrl *url.URL) (*url.URL, error) { + // For custom proxy, generate a new proxyed URL: + // [Proxy Scheme][Proxy Host][Proxy Port][ProxyPath + OriginSrcPath][OriginSrcQuery + ProxyQuery] + if policy.Settings.CustomProxy { + proxy, err := url.Parse(policy.Settings.ProxyServer) + if err != nil { + return nil, fmt.Errorf("failed to parse proxy URL: %w", err) + } + proxy.Path = path.Join(proxy.Path, strings.TrimPrefix(srcUrl.Path, "/")) + q := proxy.Query() + if len(q) == 0 { + proxy.RawQuery = srcUrl.RawQuery + } else { + // Merge query parameters + srcQ := srcUrl.Query() + for k, _ := range srcQ { + q.Set(k, srcQ.Get(k)) + } + + proxy.RawQuery = q.Encode() + } + + srcUrl = proxy + } + + return srcUrl, nil +} diff --git a/pkg/filemanager/fs/dbfs/dbfs.go b/pkg/filemanager/fs/dbfs/dbfs.go new file mode 100644 index 00000000..7d7f7a4f --- /dev/null +++ b/pkg/filemanager/fs/dbfs/dbfs.go @@ -0,0 +1,877 @@ +package dbfs + +import ( + "context" + "errors" + "fmt" + "math/rand" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "github.com/samber/lo" + "golang.org/x/tools/container/intsets" +) + +const ( + ContextHintHeader = constants.CrHeaderPrefix + "Context-Hint" + NavigatorStateCachePrefix = "navigator_state_" + ContextHintTTL = 5 * 60 // 5 minutes + + folderSummaryCachePrefix = "folder_summary_" +) + +type ( + ContextHintCtxKey struct{} + ByPassOwnerCheckCtxKey struct{} +) + +func NewDatabaseFS(u *ent.User, fileClient inventory.FileClient, shareClient inventory.ShareClient, + l logging.Logger, ls lock.LockSystem, settingClient setting.Provider, + storagePolicyClient inventory.StoragePolicyClient, hasher hashid.Encoder, userClient inventory.UserClient, + cache, stateKv cache.Driver) fs.FileSystem { + return &DBFS{ + user: u, + navigators: make(map[string]Navigator), + fileClient: fileClient, + shareClient: shareClient, + l: l, + ls: ls, + settingClient: settingClient, + storagePolicyClient: storagePolicyClient, + hasher: hasher, + userClient: userClient, + cache: cache, + stateKv: stateKv, + } +} + +type DBFS struct { + user *ent.User + navigators map[string]Navigator + fileClient inventory.FileClient + userClient inventory.UserClient + storagePolicyClient inventory.StoragePolicyClient + shareClient inventory.ShareClient + l logging.Logger + ls lock.LockSystem + settingClient setting.Provider + hasher hashid.Encoder + cache cache.Driver + stateKv cache.Driver + mu sync.Mutex +} + +func (f *DBFS) Recycle() { + for _, navigator := range f.navigators { + navigator.Recycle() + } +} + +func (f *DBFS) GetEntity(ctx context.Context, entityID int) (fs.Entity, error) { + if entityID == 0 { + return fs.NewEmptyEntity(f.user), nil + } + + files, _, err := f.fileClient.GetEntitiesByIDs(ctx, []int{entityID}, 0) + if err != nil { + return nil, fmt.Errorf("failed to get entity: %w", err) + } + + if len(files) == 0 { + return nil, 
fs.ErrEntityNotExist + } + + return fs.NewEntity(files[0]), nil + +} + +func (f *DBFS) List(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.ListFileResult, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // Get navigator + navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityListChildren) + if err != nil { + return nil, nil, err + } + + searchParams := path.SearchParameters() + isSearching := searchParams != nil + + // Validate pagination args + props := navigator.Capabilities(isSearching) + if o.PageSize > props.MaxPageSize { + o.PageSize = props.MaxPageSize + } + + parent, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + return nil, nil, fmt.Errorf("Parent not exist: %w", err) + } + + var hintId *uuid.UUID + if o.generateContextHint { + newHintId := uuid.Must(uuid.NewV4()) + hintId = &newHintId + } + + if o.loadFilePublicMetadata { + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + } + if o.loadFileShareIfOwned && parent != nil && parent.OwnerID() == f.user.ID { + ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true) + } + + var streamCallback func([]*File) + if o.streamListResponseCallback != nil { + streamCallback = func(files []*File) { + o.streamListResponseCallback(parent, lo.Map(files, func(item *File, index int) fs.File { + return item + })) + } + } + + children, err := navigator.Children(ctx, parent, &ListArgs{ + Page: &inventory.PaginationArgs{ + Page: o.FsOption.Page, + PageSize: o.PageSize, + OrderBy: o.OrderBy, + Order: inventory.OrderDirection(o.OrderDirection), + UseCursorPagination: o.useCursorPagination, + PageToken: o.pageToken, + }, + Search: searchParams, + StreamCallback: streamCallback, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to get children: %w", err) + } + + var storagePolicy *ent.StoragePolicy + if parent != nil { + storagePolicy, err = f.getPreferredPolicy(ctx, parent) + if err != nil { + f.l.Warning("Failed to get preferred policy: %v", err) + } + } + + return parent, &fs.ListFileResult{ + Files: lo.Map(children.Files, func(item *File, index int) fs.File { + return item + }), + Props: props, + Pagination: children.Pagination, + ContextHint: hintId, + RecursionLimitReached: children.RecursionLimitReached, + MixedType: children.MixedType, + SingleFileView: children.SingleFileView, + Parent: parent, + StoragePolicy: storagePolicy, + }, nil +} + +func (f *DBFS) Capacity(ctx context.Context, u *ent.User) (*fs.Capacity, error) { + // First, get user's available storage packs + var ( + res = &fs.Capacity{} + ) + + requesterGroup, err := u.Edges.GroupOrErr() + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user's group", err) + } + + res.Used = f.user.Storage + res.Total = requesterGroup.MaxStorage + return res, nil +} + +func (f *DBFS) CreateEntity(ctx context.Context, file fs.File, policy *ent.StoragePolicy, + entityType types.EntityType, req *fs.UploadRequest, opts ...fs.Option) (fs.Entity, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // If uploader specified previous latest version ID (etag), we should check if it's still valid. 
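+	// This acts as an optimistic-concurrency check: the decoded version must still match the
+	// file's primary entity, otherwise the edit is rejected with fs.ErrStaleVersion below.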
+ if o.previousVersion != "" { + entityId, err := f.hasher.Decode(o.previousVersion, hashid.EntityID) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Unknown version ID", err) + } + + entities, err := file.(*File).Model.Edges.EntitiesOrErr() + if err != nil || entities == nil { + return nil, fmt.Errorf("create entity: previous entities not load") + } + + // File is stale during edit if the latest entity is not the same as the one specified by uploader. + if e := file.PrimaryEntity(); e == nil || e.ID() != entityId { + return nil, fs.ErrStaleVersion + } + } + + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + fileModel := file.(*File).Model + if o.removeStaleEntities { + storageDiff, err := fc.RemoveStaleEntities(ctx, fileModel) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to remove stale entities", err) + } + + tx.AppendStorageDiff(storageDiff) + } + + entity, storageDiff, err := fc.CreateEntity(ctx, fileModel, &inventory.EntityParameters{ + OwnerID: file.(*File).Owner().ID, + EntityType: entityType, + StoragePolicyID: policy.ID, + Source: req.Props.SavePath, + Size: req.Props.Size, + UploadSessionID: uuid.FromStringOrNil(o.UploadRequest.Props.UploadSessionID), + }) + if err != nil { + _ = inventory.Rollback(tx) + + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create entity", err) + } + tx.AppendStorageDiff(storageDiff) + + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit create change", err) + } + + return fs.NewEntity(entity), nil +} + +func (f *DBFS) PatchMetadata(ctx context.Context, path []*fs.URI, metas ...fs.MetadataPatch) error { + ae := serializer.NewAggregateError() + targets := make([]*File, 0, len(path)) + for _, p := range path { + navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityUpdateMetadata, NavigatorCapabilityLockFile) + if err != nil { + ae.Add(p.String(), err) + continue + } + + target, err := f.getFileByPath(ctx, navigator, p) + if err != nil { + ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err)) + continue + } + + // Require Update permission + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID { + return fs.ErrOwnerOnly.WithError(fmt.Errorf("permission denied")) + } + + if target.IsRootFolder() { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move root folder"))) + continue + } + + targets = append(targets, target) + } + + if len(targets) == 0 { + return ae.Aggregate() + } + + // Lock all targets + lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath { + return &LockByPath{value.Uri(true), value, value.Type(), ""} + }) + ls, err := f.acquireByPath(ctx, -1, f.user, true, fs.LockApp(fs.ApplicationUpdateMetadata), lockTargets...) 
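+	// Release is deferred before the error check because acquireByPath may return a partially
+	// populated session even on failure; any tokens it did acquire still need to be released.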
+ defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return err + } + + metadataMap := make(map[string]string) + privateMap := make(map[string]bool) + deleted := make([]string, 0) + for _, meta := range metas { + if meta.Remove { + deleted = append(deleted, meta.Key) + continue + } + metadataMap[meta.Key] = meta.Value + if meta.Private { + privateMap[meta.Key] = meta.Private + } + } + + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + for _, target := range targets { + if err := fc.UpsertMetadata(ctx, target.Model, metadataMap, privateMap); err != nil { + _ = inventory.Rollback(tx) + return fmt.Errorf("failed to upsert metadata: %w", err) + } + + if len(deleted) > 0 { + if err := fc.RemoveMetadata(ctx, target.Model, deleted...); err != nil { + _ = inventory.Rollback(tx) + return fmt.Errorf("failed to remove metadata: %w", err) + } + } + } + + if err := inventory.Commit(tx); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to commit metadata change", err) + } + + return ae.Aggregate() +} + +func (f *DBFS) SharedAddressTranslation(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.URI, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // Get navigator + navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...) + if err != nil { + return nil, nil, err + } + + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + if o.loadFileEntities { + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + } + + uriTranslation := func(target *File, rebase bool) (fs.File, *fs.URI, error) { + // Translate shared address to real address + metadata := target.Metadata() + if metadata == nil { + if err := f.fileClient.QueryMetadata(ctx, target.Model); err != nil { + return nil, nil, fmt.Errorf("failed to query metadata: %w", err) + } + metadata = target.Metadata() + } + redirect, ok := metadata[MetadataSharedRedirect] + if !ok { + return nil, nil, fmt.Errorf("missing metadata %s in symbolic folder %s", MetadataSharedRedirect, path) + } + + redirectUri, err := fs.NewUriFromString(redirect) + if err != nil { + return nil, nil, fmt.Errorf("invalid redirect uri %s in symbolic folder %s", redirect, path) + } + newUri := redirectUri + if rebase { + newUri = redirectUri.Rebase(path, target.Uri(false)) + } + return f.SharedAddressTranslation(ctx, newUri, opts...) + } + + target, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + if errors.Is(err, ErrSymbolicFolderFound) && target.Type() == types.FileTypeFolder { + return uriTranslation(target, true) + } + + if !ent.IsNotFound(err) { + return nil, nil, fmt.Errorf("failed to get target file: %w", err) + } + + // Request URI does not exist, return most recent ancestor + return target, path, err + } + + if target.IsSymbolic() { + return uriTranslation(target, false) + } + + return target, path, nil +} + +func (f *DBFS) Get(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // Get navigator + navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...) 
+ if err != nil { + return nil, err + } + + if o.loadFilePublicMetadata || o.extendedInfo { + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + } + + if o.loadFileEntities || o.extendedInfo || o.loadFolderSummary { + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + } + + if o.loadFileShareIfOwned { + ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true) + } + + if o.loadEntityUser { + ctx = context.WithValue(ctx, inventory.LoadEntityUser{}, true) + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + return nil, fmt.Errorf("failed to get target file: %w", err) + } + + if o.extendedInfo && target != nil { + extendedInfo := &fs.FileExtendedInfo{ + StorageUsed: target.SizeUsed(), + EntityStoragePolicies: make(map[int]*ent.StoragePolicy), + } + policyID := target.PolicyID() + if policyID > 0 { + policy, err := f.storagePolicyClient.GetPolicyByID(ctx, policyID) + if err == nil { + extendedInfo.StoragePolicy = policy + } + } + + target.FileExtendedInfo = extendedInfo + if target.OwnerID() == f.user.ID || f.user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) { + target.FileExtendedInfo.Shares = target.Model.Edges.Shares + } + + entities := target.Entities() + for _, entity := range entities { + if _, ok := extendedInfo.EntityStoragePolicies[entity.PolicyID()]; !ok { + policy, err := f.storagePolicyClient.GetPolicyByID(ctx, entity.PolicyID()) + if err != nil { + return nil, fmt.Errorf("failed to get policy: %w", err) + } + + extendedInfo.EntityStoragePolicies[entity.PolicyID()] = policy + } + } + } + + // Calculate folder summary if requested + if o.loadFolderSummary && target != nil && target.Type() == types.FileTypeFolder { + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID { + return nil, fs.ErrOwnerOnly + } + + // first, try to load from cache + summary, ok := f.cache.Get(fmt.Sprintf("%s%d", folderSummaryCachePrefix, target.ID())) + if ok { + summaryTyped := summary.(fs.FolderSummary) + target.FileFolderSummary = &summaryTyped + } else { + // cache miss, walk the folder to get the summary + newSummary := &fs.FolderSummary{Completed: true} + if f.user.Edges.Group == nil { + return nil, fmt.Errorf("user group not loaded") + } + limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1) + + // disable load metadata to speed up + ctxWalk := context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, false) + if err := navigator.Walk(ctxWalk, []*File{target}, limit, intsets.MaxInt, func(files []*File, l int) error { + for _, file := range files { + if file.ID() == target.ID() { + continue + } + if file.Type() == types.FileTypeFile { + newSummary.Files++ + } else { + newSummary.Folders++ + } + + newSummary.Size += file.SizeUsed() + } + return nil + }); err != nil { + if !errors.Is(err, ErrFileCountLimitedReached) { + return nil, fmt.Errorf("failed to walk: %w", err) + } + + newSummary.Completed = false + } + + // cache the summary + newSummary.CalculatedAt = time.Now() + f.cache.Set(fmt.Sprintf("%s%d", folderSummaryCachePrefix, target.ID()), newSummary, f.settingClient.FolderPropsCacheTTL(ctx)) + target.FileFolderSummary = newSummary + } + } + + if target == nil { + return nil, fmt.Errorf("cannot get root file with nil root") + } + + return target, nil +} + +func (f *DBFS) CheckCapability(ctx context.Context, uri *fs.URI, opts ...fs.Option) error { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // Get navigator + _, 
err := f.getNavigator(ctx, uri, o.requiredCapabilities...) + if err != nil { + return err + } + + return nil +} + +func (f *DBFS) Walk(ctx context.Context, path *fs.URI, depth int, walk fs.WalkFunc, opts ...fs.Option) error { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + if o.loadFilePublicMetadata { + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + } + + if o.loadFileEntities { + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + } + + // Get navigator + navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...) + if err != nil { + return err + } + + target, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + return err + } + + // Require Read permission + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID { + return fs.ErrOwnerOnly + } + + // Walk + if f.user.Edges.Group == nil { + return fmt.Errorf("user group not loaded") + } + limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1) + + if err := navigator.Walk(ctx, []*File{target}, limit, depth, func(files []*File, l int) error { + for _, file := range files { + if err := walk(file, l); err != nil { + return err + } + } + return nil + }); err != nil { + return fmt.Errorf("failed to walk: %w", err) + } + + return nil +} + +func (f *DBFS) ExecuteNavigatorHooks(ctx context.Context, hookType fs.HookType, file fs.File) error { + navigator, err := f.getNavigator(ctx, file.Uri(false)) + if err != nil { + return err + } + + if dbfsFile, ok := file.(*File); ok { + return navigator.ExecuteHook(ctx, hookType, dbfsFile) + } + + return nil +} + +// createFile creates a file with given name and type under given parent folder +func (f *DBFS) createFile(ctx context.Context, parent *File, name string, fileType types.FileType, o *dbfsOption) (*File, error) { + createFileArgs := &inventory.CreateFileParameters{ + FileType: fileType, + Name: name, + MetadataPrivateMask: make(map[string]bool), + Metadata: make(map[string]string), + IsSymbolic: o.isSymbolicLink, + } + + if o.Metadata != nil { + for k, v := range o.Metadata { + createFileArgs.Metadata[k] = v + } + } + + if o.preferredStoragePolicy != nil { + createFileArgs.StoragePolicyID = o.preferredStoragePolicy.ID + } else { + // get preferred storage policy + policy, err := f.getPreferredPolicy(ctx, parent) + if err != nil { + return nil, err + } + + createFileArgs.StoragePolicyID = policy.ID + } + + if o.UploadRequest != nil { + createFileArgs.EntityParameters = &inventory.EntityParameters{ + EntityType: types.EntityTypeVersion, + Source: o.UploadRequest.Props.SavePath, + Size: o.UploadRequest.Props.Size, + ModifiedAt: o.UploadRequest.Props.LastModified, + UploadSessionID: uuid.FromStringOrNil(o.UploadRequest.Props.UploadSessionID), + } + } + + // Start transaction to create files + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + file, entity, storageDiff, err := fc.CreateFile(ctx, parent.Model, createFileArgs) + if err != nil { + _ = inventory.Rollback(tx) + if ent.IsConstraintError(err) { + return nil, fs.ErrFileExisted.WithError(err) + } + + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create file", err) + } + + tx.AppendStorageDiff(storageDiff) + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit create 
change", err) + } + + file.SetEntities([]*ent.Entity{entity}) + return newFile(parent, file), nil +} + +// getPreferredPolicy tries to get the preferred storage policy for the given file. +func (f *DBFS) getPreferredPolicy(ctx context.Context, file *File) (*ent.StoragePolicy, error) { + ownerGroup := file.Owner().Edges.Group + if ownerGroup == nil { + return nil, fmt.Errorf("owner group not loaded") + } + + groupPolicy, err := f.storagePolicyClient.GetByGroup(ctx, ownerGroup) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get available storage policies", err) + } + + return groupPolicy, nil +} + +func (f *DBFS) getFileByPath(ctx context.Context, navigator Navigator, path *fs.URI) (*File, error) { + file, err := navigator.To(ctx, path) + if err != nil && errors.Is(err, ErrFsNotInitialized) { + // Initialize file system for user if root folder does not exist. + uid := path.ID(hashid.EncodeUserID(f.hasher, f.user.ID)) + uidInt, err := f.hasher.Decode(uid, hashid.UserID) + if err != nil { + return nil, fmt.Errorf("failed to decode user ID: %w", err) + } + + if err := f.initFs(ctx, uidInt); err != nil { + return nil, fmt.Errorf("failed to initialize file system: %w", err) + } + return navigator.To(ctx, path) + } + + return file, err +} + +// initFs initializes the file system for the user. +func (f *DBFS) initFs(ctx context.Context, uid int) error { + f.l.Info("Initialize database file system for user %q", f.user.Email) + _, err := f.fileClient.CreateFolder(ctx, nil, + &inventory.CreateFolderParameters{ + Owner: uid, + Name: inventory.RootFolderName, + }) + if err != nil { + return fmt.Errorf("failed to create root folder: %w", err) + } + + return nil +} + +func (f *DBFS) getNavigator(ctx context.Context, path *fs.URI, requiredCapabilities ...NavigatorCapability) (Navigator, error) { + pathFs := path.FileSystem() + config := f.settingClient.DBFS(ctx) + navigatorId := f.navigatorId(path) + var ( + res Navigator + ) + f.mu.Lock() + defer f.mu.Unlock() + if navigator, ok := f.navigators[navigatorId]; ok { + res = navigator + } else { + var n Navigator + switch pathFs { + case constants.FileSystemMy: + n = NewMyNavigator(f.user, f.fileClient, f.userClient, f.l, config, f.hasher) + case constants.FileSystemShare: + n = NewShareNavigator(f.user, f.fileClient, f.shareClient, f.l, config, f.hasher) + case constants.FileSystemTrash: + n = NewTrashNavigator(f.user, f.fileClient, f.l, config, f.hasher) + case constants.FileSystemSharedWithMe: + n = NewSharedWithMeNavigator(f.user, f.fileClient, f.l, config, f.hasher) + default: + return nil, fmt.Errorf("unknown file system %q", pathFs) + } + + // retrieve state if context hint is provided + if stateID, ok := ctx.Value(ContextHintCtxKey{}).(uuid.UUID); ok && stateID != uuid.Nil { + cacheKey := NavigatorStateCachePrefix + stateID.String() + "_" + navigatorId + if stateRaw, ok := f.stateKv.Get(cacheKey); ok { + if err := n.RestoreState(stateRaw.(State)); err != nil { + f.l.Warning("Failed to restore state for navigator %q: %s", navigatorId, err) + } else { + f.l.Info("Navigator %q restored state (%q) successfully", navigatorId, stateID) + } + } else { + // State expire, refresh it + n.PersistState(f.stateKv, cacheKey) + } + } + + f.navigators[navigatorId] = n + res = n + } + + // Check fs capabilities + capabilities := res.Capabilities(false).Capability + for _, capability := range requiredCapabilities { + if !capabilities.Enabled(int(capability)) { + return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("action %q 
is not supported under current fs", capability)) + } + } + + return res, nil +} + +func (f *DBFS) navigatorId(path *fs.URI) string { + uidHashed := hashid.EncodeUserID(f.hasher, f.user.ID) + switch path.FileSystem() { + case constants.FileSystemMy: + return fmt.Sprintf("%s/%s/%d", constants.FileSystemMy, path.ID(uidHashed), f.user.ID) + case constants.FileSystemShare: + return fmt.Sprintf("%s/%s/%d", constants.FileSystemShare, path.ID(uidHashed), f.user.ID) + case constants.FileSystemTrash: + return fmt.Sprintf("%s/%s", constants.FileSystemTrash, path.ID(uidHashed)) + default: + return fmt.Sprintf("%s/%s/%d", path.FileSystem(), path.ID(uidHashed), f.user.ID) + } +} + +// generateSavePath generates the physical save path for the upload request. +func generateSavePath(policy *ent.StoragePolicy, req *fs.UploadRequest, user *ent.User) string { + baseTable := map[string]string{ + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), + "{randomnum2}": strconv.Itoa(rand.Intn(2)), + "{randomnum3}": strconv.Itoa(rand.Intn(3)), + "{randomnum4}": strconv.Itoa(rand.Intn(4)), + "{randomnum8}": strconv.Itoa(rand.Intn(8)), + "{uid}": strconv.Itoa(user.ID), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + "{year}": time.Now().Format("2006"), + "{month}": time.Now().Format("01"), + "{day}": time.Now().Format("02"), + "{hour}": time.Now().Format("15"), + "{minute}": time.Now().Format("04"), + "{second}": time.Now().Format("05"), + } + + dirRule := policy.DirNameRule + dirRule = filepath.ToSlash(dirRule) + dirRule = util.Replace(baseTable, dirRule) + dirRule = util.Replace(map[string]string{ + "{path}": req.Props.Uri.Dir() + fs.Separator, + }, dirRule) + + originName := req.Props.Uri.Name() + nameTable := map[string]string{ + "{originname}": originName, + "{ext}": filepath.Ext(originName), + "{originname_without_ext}": strings.TrimSuffix(originName, filepath.Ext(originName)), + "{uuid}": uuid.Must(uuid.NewV4()).String(), + } + + nameRule := policy.FileNameRule + nameRule = util.Replace(baseTable, nameRule) + nameRule = util.Replace(nameTable, nameRule) + + return path.Join(path.Clean(dirRule), nameRule) +} + +func canMoveOrCopyTo(src, dst *fs.URI, isCopy bool) bool { + if isCopy { + return src.FileSystem() == dst.FileSystem() && src.FileSystem() == constants.FileSystemMy + } else { + switch src.FileSystem() { + case constants.FileSystemMy: + return dst.FileSystem() == constants.FileSystemMy || dst.FileSystem() == constants.FileSystemTrash + case constants.FileSystemTrash: + return dst.FileSystem() == constants.FileSystemMy + + } + } + + return false +} + +func allAncestors(targets []*File) []*ent.File { + return lo.Map( + lo.UniqBy( + lo.FlatMap(targets, func(value *File, index int) []*File { + return value.Ancestors() + }), + func(item *File) int { + return item.ID() + }, + ), + func(item *File, index int) *ent.File { + return item.Model + }, + ) +} + +func WithBypassOwnerCheck(ctx context.Context) context.Context { + return context.WithValue(ctx, ByPassOwnerCheckCtxKey{}, true) +} diff --git a/pkg/filemanager/fs/dbfs/file.go b/pkg/filemanager/fs/dbfs/file.go new file mode 100644 index 00000000..a6ea2224 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/file.go @@ -0,0 +1,335 @@ +package dbfs + +import ( + "encoding/gob" + "path" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + 
"github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/samber/lo" +) + +func init() { + gob.Register(File{}) + gob.Register(shareNavigatorState{}) + gob.Register(map[string]*File{}) + gob.Register(map[int]*File{}) +} + +var filePool = &sync.Pool{ + New: func() any { + return &File{ + Children: make(map[string]*File), + } + }, +} + +type ( + File struct { + Model *ent.File + Children map[string]*File + Parent *File + Path [2]*fs.URI + OwnerModel *ent.User + IsUserRoot bool + CapabilitiesBs *boolset.BooleanSet + FileExtendedInfo *fs.FileExtendedInfo + FileFolderSummary *fs.FolderSummary + + mu *sync.Mutex + } +) + +const ( + MetadataSysPrefix = "sys:" + MetadataUploadSessionPrefix = MetadataSysPrefix + "upload_session" + MetadataUploadSessionID = MetadataUploadSessionPrefix + "_id" + MetadataSharedRedirect = MetadataSysPrefix + "shared_redirect" + MetadataRestoreUri = MetadataSysPrefix + "restore_uri" + MetadataExpectedCollectTime = MetadataSysPrefix + "expected_collect_time" + + ThumbMetadataPrefix = "thumb:" + ThumbDisabledKey = ThumbMetadataPrefix + "disabled" + + pathIndexRoot = 0 + pathIndexUser = 1 +) + +func (f *File) Name() string { + return f.Model.Name +} + +func (f *File) IsNil() bool { + return f == nil +} + +func (f *File) DisplayName() string { + if uri, ok := f.Metadata()[MetadataRestoreUri]; ok { + restoreUri, err := fs.NewUriFromString(uri) + if err != nil { + return f.Name() + } + + return path.Base(restoreUri.Path()) + } + + return f.Name() +} + +func (f *File) CanHaveChildren() bool { + return f.Type() == types.FileTypeFolder && !f.IsSymbolic() +} + +func (f *File) Ext() string { + return util.Ext(f.Name()) +} + +func (f *File) ID() int { + return f.Model.ID +} + +func (f *File) IsSymbolic() bool { + return f.Model.IsSymbolic +} + +func (f *File) Type() types.FileType { + return types.FileType(f.Model.Type) +} + +func (f *File) Size() int64 { + return f.Model.Size +} + +func (f *File) SizeUsed() int64 { + return lo.SumBy(f.Entities(), func(item fs.Entity) int64 { + return item.Size() + }) +} + +func (f *File) UpdatedAt() time.Time { + return f.Model.UpdatedAt +} + +func (f *File) CreatedAt() time.Time { + return f.Model.CreatedAt +} + +func (f *File) ExtendedInfo() *fs.FileExtendedInfo { + return f.FileExtendedInfo +} + +func (f *File) Owner() *ent.User { + parent := f + for parent != nil { + if parent.OwnerModel != nil { + return parent.OwnerModel + } + parent = parent.Parent + } + + return nil +} + +func (f *File) OwnerID() int { + return f.Model.OwnerID +} + +func (f *File) Shared() bool { + return len(f.Model.Edges.Shares) > 0 +} + +func (f *File) Metadata() map[string]string { + if f.Model.Edges.Metadata == nil { + return nil + } + return lo.Associate(f.Model.Edges.Metadata, func(item *ent.Metadata) (string, string) { + return item.Name, item.Value + }) +} + +// Uri returns the URI of the file. +// If isRoot is true, the URI will be returned from owner's view. +// Otherwise, the URI will be returned from user's view. +func (f *File) Uri(isRoot bool) *fs.URI { + index := 1 + if isRoot { + index = 0 + } + if f.Path[index] != nil || f.Parent == nil { + return f.Path[index] + } + + // Find the root file + elements := make([]string, 0) + parent := f + for parent.Parent != nil && parent.Path[index] == nil { + elements = append([]string{parent.Name()}, elements...) 
+ parent = parent.Parent + } + + if parent.Path[index] == nil { + return nil + } + + return parent.Path[index].Join(elements...) +} + +// UserRoot return the root file from user's view. +func (f *File) UserRoot() *File { + root := f + for root != nil && !root.IsUserRoot { + root = root.Parent + } + + return root +} + +// Root return the root file from owner's view. +func (f *File) Root() *File { + root := f + for root.Parent != nil { + root = root.Parent + } + + return root +} + +// RootUri return the URI of the user root file under owner's view. +func (f *File) RootUri() *fs.URI { + return f.UserRoot().Uri(true) +} + +func (f *File) Replace(model *ent.File) *File { + f.mu.Lock() + delete(f.Parent.Children, f.Model.Name) + f.mu.Unlock() + + defer f.Recycle() + replaced := newFile(f.Parent, model) + if f.IsRootFile() { + // If target is a root file, the user path should remain the same. + replaced.Path[pathIndexUser] = f.Path[pathIndexUser] + } + + return replaced +} + +// Ancestors return all ancestors of the file, until the owner root is reached. +func (f *File) Ancestors() []*File { + return f.AncestorsChain()[1:] +} + +// AncestorsChain return all ancestors of the file (including itself), until the owner root is reached. +func (f *File) AncestorsChain() []*File { + ancestors := make([]*File, 0) + parent := f + for parent != nil { + ancestors = append(ancestors, parent) + parent = parent.Parent + } + + return ancestors +} + +func (f *File) PolicyID() int { + root := f + return root.Model.StoragePolicyFiles +} + +// IsRootFolder return true if the file is the root folder under user's view. +func (f *File) IsRootFolder() bool { + return f.Type() == types.FileTypeFolder && f.IsRootFile() +} + +// IsRootFile return true if the file is the root file under user's view. 
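+// A file qualifies when it carries the inventory root folder name, or when its user-view path
+// resolves to the path separator or an empty string.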
+func (f *File) IsRootFile() bool { + uri := f.Uri(false) + p := uri.Path() + return f.Model.Name == inventory.RootFolderName || p == fs.Separator || p == "" +} + +func (f *File) Entities() []fs.Entity { + return lo.Map(f.Model.Edges.Entities, func(item *ent.Entity, index int) fs.Entity { + return fs.NewEntity(item) + }) +} + +func (f *File) PrimaryEntity() fs.Entity { + primary, _ := lo.Find(f.Model.Edges.Entities, func(item *ent.Entity) bool { + return item.Type == int(types.EntityTypeVersion) && item.ID == f.Model.PrimaryEntity + }) + if primary != nil { + return fs.NewEntity(primary) + } + + return fs.NewEmptyEntity(f.Owner()) +} + +func (f *File) PrimaryEntityID() int { + return f.Model.PrimaryEntity +} + +func (f *File) FolderSummary() *fs.FolderSummary { + return f.FileFolderSummary +} + +func (f *File) Capabilities() *boolset.BooleanSet { + return f.CapabilitiesBs +} + +func newFile(parent *File, model *ent.File) *File { + f := filePool.Get().(*File) + f.Model = model + + if parent != nil { + f.Parent = parent + parent.mu.Lock() + parent.Children[model.Name] = f + if parent.Path[pathIndexUser] != nil { + f.Path[pathIndexUser] = parent.Path[pathIndexUser].Join(model.Name) + } + + if parent.Path[pathIndexRoot] != nil { + f.Path[pathIndexRoot] = parent.Path[pathIndexRoot].Join(model.Name) + } + + f.CapabilitiesBs = parent.CapabilitiesBs + f.mu = parent.mu + parent.mu.Unlock() + } else { + f.mu = &sync.Mutex{} + } + + return f +} + +func newParentFile(parent *ent.File, child *File) *File { + newParent := newFile(nil, parent) + newParent.Children[child.Name()] = child + child.Parent = newParent + newParent.mu = child.mu + return newParent +} + +func (f *File) Recycle() { + for _, child := range f.Children { + child.Recycle() + } + + f.Model = nil + f.Children = make(map[string]*File) + f.Path[0] = nil + f.Path[1] = nil + f.Parent = nil + f.OwnerModel = nil + f.IsUserRoot = false + f.mu = nil + + filePool.Put(f) +} diff --git a/pkg/filemanager/fs/dbfs/global.go b/pkg/filemanager/fs/dbfs/global.go new file mode 100644 index 00000000..f9c0f70d --- /dev/null +++ b/pkg/filemanager/fs/dbfs/global.go @@ -0,0 +1,55 @@ +package dbfs + +import ( + "context" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/samber/lo" +) + +func (f *DBFS) StaleEntities(ctx context.Context, entities ...int) ([]fs.Entity, error) { + res, err := f.fileClient.StaleEntities(ctx, entities...) 
+ if err != nil { + return nil, err + } + + return lo.Map(res, func(e *ent.Entity, i int) fs.Entity { + return fs.NewEntity(e) + }), nil +} + +func (f *DBFS) AllFilesInTrashBin(ctx context.Context, opts ...fs.Option) (*fs.ListFileResult, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + navigator, err := f.getNavigator(ctx, newTrashUri(""), NavigatorCapabilityListChildren) + if err != nil { + return nil, err + } + + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + children, err := navigator.Children(ctx, nil, &ListArgs{ + Page: &inventory.PaginationArgs{ + Page: o.FsOption.Page, + PageSize: o.PageSize, + OrderBy: o.OrderBy, + Order: inventory.OrderDirection(o.OrderDirection), + UseCursorPagination: o.useCursorPagination, + PageToken: o.pageToken, + }, + }) + if err != nil { + return nil, err + } + + return &fs.ListFileResult{ + Files: lo.Map(children.Files, func(item *File, index int) fs.File { + return item + }), + Pagination: children.Pagination, + RecursionLimitReached: children.RecursionLimitReached, + }, nil +} diff --git a/pkg/filemanager/fs/dbfs/lock.go b/pkg/filemanager/fs/dbfs/lock.go new file mode 100644 index 00000000..3ebdd26d --- /dev/null +++ b/pkg/filemanager/fs/dbfs/lock.go @@ -0,0 +1,325 @@ +package dbfs + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/samber/lo" +) + +type ( + LockSession struct { + Tokens map[string]string + TokenStack [][]string + } + + LockByPath struct { + Uri *fs.URI + ClosestAncestor *File + Type types.FileType + Token string + } + + AlwaysIncludeTokenCtx struct{} +) + +func (f *DBFS) ConfirmLock(ctx context.Context, ancestor fs.File, uri *fs.URI, token ...string) (func(), fs.LockSession, error) { + session := LockSessionFromCtx(ctx) + lockUri := ancestor.RootUri().JoinRaw(uri.PathTrimmed()) + ns, root, lKey := lockTupleFromUri(lockUri, f.user, f.hasher) + lc := lock.LockInfo{ + Ns: ns, + Root: root, + Token: token, + } + + // Skip if already locked in current session + if _, ok := session.Tokens[lKey]; ok { + return func() {}, session, nil + } + + release, tokenHit, err := f.ls.Confirm(time.Now(), lc) + if err != nil { + return nil, nil, err + } + + session.Tokens[lKey] = tokenHit + stackIndex := len(session.TokenStack) - 1 + session.TokenStack[stackIndex] = append(session.TokenStack[stackIndex], lKey) + return release, session, nil +} + +func (f *DBFS) Lock(ctx context.Context, d time.Duration, requester *ent.User, zeroDepth bool, application lock.Application, + uri *fs.URI, token string) (fs.LockSession, error) { + // Get navigator + navigator, err := f.getNavigator(ctx, uri, NavigatorCapabilityLockFile) + if err != nil { + return nil, err + } + + ancestor, err := f.getFileByPath(ctx, navigator, uri) + if err != nil && !ent.IsNotFound(err) { + return nil, fmt.Errorf("failed to get ancestor: %w", err) + } + + if ancestor.IsRootFolder() && ancestor.Uri(false).IsSame(uri, hashid.EncodeUserID(f.hasher, f.user.ID)) { + return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot lock root folder")) + } + + // Lock require create or update permission + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && ancestor.Owner().ID != requester.ID { + return nil, fs.ErrOwnerOnly + } + + t 
:= types.FileTypeFile + if ancestor.Uri(false).IsSame(uri, hashid.EncodeUserID(f.hasher, f.user.ID)) { + t = ancestor.Type() + } + lr := &LockByPath{ + Uri: ancestor.RootUri().JoinRaw(uri.PathTrimmed()), + ClosestAncestor: ancestor, + Type: t, + Token: token, + } + ls, err := f.acquireByPath(ctx, d, requester, zeroDepth, application, lr) + if err != nil { + return nil, err + } + + return ls, nil +} + +func (f *DBFS) Unlock(ctx context.Context, tokens ...string) error { + return f.ls.Unlock(time.Now(), tokens...) +} + +func (f *DBFS) Refresh(ctx context.Context, d time.Duration, token string) (lock.LockDetails, error) { + return f.ls.Refresh(time.Now(), d, token) +} + +func (f *DBFS) acquireByPath(ctx context.Context, duration time.Duration, + requester *ent.User, zeroDepth bool, application lock.Application, locks ...*LockByPath) (*LockSession, error) { + session := LockSessionFromCtx(ctx) + + // Prepare lock details for each file + lockDetails := make([]lock.LockDetails, 0, len(locks)) + lockedRequest := make([]*LockByPath, 0, len(locks)) + for _, l := range locks { + ns, root, lKey := lockTupleFromUri(l.Uri, f.user, f.hasher) + ld := lock.LockDetails{ + Owner: lock.Owner{ + Application: application, + }, + Ns: ns, + Root: root, + ZeroDepth: zeroDepth, + Duration: duration, + Type: l.Type, + Token: l.Token, + } + + // Skip if already locked in current session + if _, ok := session.Tokens[lKey]; ok { + continue + } + + lockDetails = append(lockDetails, ld) + lockedRequest = append(lockedRequest, l) + } + + // Acquire lock + tokens, err := f.ls.Create(time.Now(), lockDetails...) + if len(tokens) > 0 { + for i, token := range tokens { + key := lockDetails[i].Key() + session.Tokens[key] = token + stackIndex := len(session.TokenStack) - 1 + session.TokenStack[stackIndex] = append(session.TokenStack[stackIndex], key) + } + } + + if err != nil { + var conflicts lock.ConflictError + if errors.As(err, &conflicts) { + // Conflict with existing lock, generate user-friendly error message + conflicts = lo.Map(conflicts, func(c *lock.ConflictDetail, index int) *lock.ConflictDetail { + lr := lockedRequest[c.Index] + if lr.ClosestAncestor.Root().Model.OwnerID == requester.ID { + // Add absolute path for owner issued lock request + c.Path = newMyUri().JoinRaw(c.Path).String() + return c + } + + // Hide token for non-owner requester + if v, ok := ctx.Value(AlwaysIncludeTokenCtx{}).(bool); !ok || !v { + c.Token = "" + } + + // If conflicted resources still under user root, expose the relative path + userRoot := lr.ClosestAncestor.UserRoot() + userRootPath := userRoot.Uri(true).Path() + if strings.HasPrefix(c.Path, userRootPath) { + c.Path = userRoot. + Uri(false). + Join(strings.Split(strings.TrimPrefix(c.Path, userRootPath), fs.Separator)...).String() + return c + } + + // Hide sensitive information for non-owner issued lock request + c.Path = "" + return c + }) + + return session, fs.ErrLockConflict.WithError(conflicts) + } + + return session, fmt.Errorf("faield to create lock: %w", err) + } + + // Check if any ancestor is modified during `getFileByPath` and `lock`. 
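+	// If a locked target or one of its ancestors changed its name, parent, owner or type in the
+	// meantime, ensureConsistency fails with fs.ErrModified so the caller can release and retry.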
+ if err := f.ensureConsistency( + ctx, + lo.Map(lockedRequest, func(item *LockByPath, index int) *File { + return item.ClosestAncestor + })..., + ); err != nil { + return session, err + } + + return session, nil +} + +func (f *DBFS) Release(ctx context.Context, session *LockSession) error { + if session == nil { + return nil + } + + stackIndex := len(session.TokenStack) - 1 + err := f.ls.Unlock(time.Now(), lo.Map(session.TokenStack[stackIndex], func(key string, index int) string { + return session.Tokens[key] + })...) + if err == nil { + for _, key := range session.TokenStack[stackIndex] { + delete(session.Tokens, key) + } + session.TokenStack = session.TokenStack[:len(session.TokenStack)-1] + } + + return err +} + +// ensureConsistency queries database for all given files and its ancestors, make sure there's no modification in +// between. This is to make sure there's no modification between navigator's first query and lock acquisition. +func (f *DBFS) ensureConsistency(ctx context.Context, files ...*File) error { + if len(files) == 0 { + return nil + } + + // Generate a list of unique files (include ancestors) to check + uniqueFiles := make(map[int]*File) + for _, file := range files { + for root := file; root != nil; root = root.Parent { + if _, ok := uniqueFiles[root.Model.ID]; ok { + // This file and its ancestors are already included + break + } + + uniqueFiles[root.Model.ID] = root + } + } + + page := 0 + fileIds := lo.Keys(uniqueFiles) + for page >= 0 { + files, next, err := f.fileClient.GetByIDs(ctx, fileIds, page) + if err != nil { + return fmt.Errorf("failed to check file consistency: %w", err) + } + + for _, file := range files { + latest := uniqueFiles[file.ID].Model + if file.Name != latest.Name || + file.FileChildren != latest.FileChildren || + file.OwnerID != latest.OwnerID || + file.Type != latest.Type { + return fs.ErrModified. + WithError(fmt.Errorf("file %s has been modified before lock acquisition", file.Name)) + } + } + + page = next + } + + return nil +} + +// LockSessionFromCtx retrieves lock session from context. If no lock session +// found, a new empty lock session will be returned. +func LockSessionFromCtx(ctx context.Context) *LockSession { + l, _ := ctx.Value(fs.LockSessionCtxKey{}).(*LockSession) + if l == nil { + ls := &LockSession{ + Tokens: make(map[string]string), + TokenStack: make([][]string, 0), + } + + l = ls + } + + l.TokenStack = append(l.TokenStack, make([]string, 0)) + return l +} + +// Exclude removes lock from session, so that it won't be released. +func (l *LockSession) Exclude(lock *LockByPath, u *ent.User, hasher hashid.Encoder) string { + _, _, lKey := lockTupleFromUri(lock.Uri, u, hasher) + foundInCurrentStack := false + token, found := l.Tokens[lKey] + if found { + stackIndex := len(l.TokenStack) - 1 + l.TokenStack[stackIndex] = lo.Filter(l.TokenStack[stackIndex], func(t string, index int) bool { + if t == lKey { + foundInCurrentStack = true + } + return t != lKey + }) + if foundInCurrentStack { + delete(l.Tokens, lKey) + return token + } + } + + return "" +} + +func (l *LockSession) LastToken() string { + stackIndex := len(l.TokenStack) - 1 + if len(l.TokenStack[stackIndex]) == 0 { + return "" + } + return l.Tokens[l.TokenStack[stackIndex][len(l.TokenStack[stackIndex])-1]] +} + +// WithAlwaysIncludeToken returns a new context with a flag to always include token in conflic response. 
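+// Without this flag, acquireByPath strips lock tokens from conflict details when the requester
+// does not own the conflicting file.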
+func WithAlwaysIncludeToken(ctx context.Context) context.Context { + return context.WithValue(ctx, AlwaysIncludeTokenCtx{}, true) +} + +func lockTupleFromUri(uri *fs.URI, u *ent.User, hasher hashid.Encoder) (string, string, string) { + id := uri.ID(hashid.EncodeUserID(hasher, u.ID)) + if id == "" { + id = strconv.Itoa(u.ID) + } + ns := fmt.Sprintf(id + "/" + string(uri.FileSystem())) + root := uri.Path() + return ns, root, ns + "/" + root +} diff --git a/pkg/filemanager/fs/dbfs/manage.go b/pkg/filemanager/fs/dbfs/manage.go new file mode 100644 index 00000000..3ac72593 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/manage.go @@ -0,0 +1,831 @@ +package dbfs + +import ( + "context" + "fmt" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/samber/lo" + "golang.org/x/tools/container/intsets" +) + +func (f *DBFS) Create(ctx context.Context, path *fs.URI, fileType types.FileType, opts ...fs.Option) (fs.File, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + // Get navigator + navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityCreateFile, NavigatorCapabilityLockFile) + if err != nil { + return nil, err + } + + // Get most recent ancestor + var ancestor *File + if o.ancestor != nil { + ancestor = o.ancestor + } else { + ancestor, err = f.getFileByPath(ctx, navigator, path) + if err != nil && !ent.IsNotFound(err) { + return nil, fmt.Errorf("failed to get ancestor: %w", err) + } + } + + if ancestor.Uri(false).IsSame(path, hashid.EncodeUserID(f.hasher, f.user.ID)) { + if ancestor.Type() == fileType { + if o.errOnConflict { + return ancestor, fs.ErrFileExisted + } + + // Target file already exist, return it. + return ancestor, nil + } + + // File with the same name but different type already exist + return nil, fs.ErrFileExisted. 
+ WithError(fmt.Errorf("object with the same name but different type %q already exist", ancestor.Type())) + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && ancestor.Owner().ID != f.user.ID { + return nil, fs.ErrOwnerOnly + } + + // Lock ancestor + lockedPath := ancestor.RootUri().JoinRaw(path.PathTrimmed()) + ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationCreate), + &LockByPath{lockedPath, ancestor, fileType, ""}) + defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return nil, err + } + + // For all ancestors in user's desired path, create folders if not exist + existedElements := ancestor.Uri(false).Elements() + desired := path.Elements() + if (len(desired)-len(existedElements) > 1) && o.noChainedCreation { + return nil, fs.ErrPathNotExist + } + + for i := len(existedElements); i < len(desired); i++ { + // Make sure parent is a folder + if !ancestor.CanHaveChildren() { + return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("parent must be a valid folder")) + } + + // Validate object name + if err := validateFileName(desired[i]); err != nil { + return nil, fs.ErrIllegalObjectName.WithError(err) + } + + if i < len(desired)-1 || fileType == types.FileTypeFolder { + args := &inventory.CreateFolderParameters{ + Owner: ancestor.Model.OwnerID, + Name: desired[i], + } + + // Apply options for last element + if i == len(desired)-1 { + if o.Metadata != nil { + args.Metadata = o.Metadata + } + args.IsSymbolic = o.isSymbolicLink + } + + // Create folder if it is not the last element or the target is a folder + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + newFolder, err := fc.CreateFolder(ctx, ancestor.Model, args) + if err != nil { + _ = inventory.Rollback(tx) + return nil, fmt.Errorf("failed to create folder %q: %w", desired[i], err) + } + + if err := inventory.Commit(tx); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit folder creation", err) + } + + ancestor = newFile(ancestor, newFolder) + } else { + file, err := f.createFile(ctx, ancestor, desired[i], fileType, o) + if err != nil { + return nil, err + } + + return file, nil + } + } + + return ancestor, nil +} + +func (f *DBFS) Rename(ctx context.Context, path *fs.URI, newName string) (fs.File, error) { + // Get navigator + navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityRenameFile, NavigatorCapabilityLockFile) + if err != nil { + return nil, err + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + return nil, fmt.Errorf("failed to get target file: %w", err) + } + oldName := target.Name() + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID { + return nil, fs.ErrOwnerOnly + } + + // Root folder cannot be modified + if target.IsRootFolder() { + return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot modify root folder")) + } + + // Validate new name + if err := validateFileName(newName); err != nil { + return nil, fs.ErrIllegalObjectName.WithError(err) + } + + // If target is a file, validate file extension + policy, err := f.getPreferredPolicy(ctx, target) + if err != nil { + return nil, err + } + + if target.Type() == types.FileTypeFile { + if err := validateExtension(newName, policy); err != nil { + return nil, fs.ErrIllegalObjectName.WithError(err) + } + } + + // Lock target + ls, err := 
f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationRename), + &LockByPath{target.Uri(true), target, target.Type(), ""}) + defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return nil, err + } + + // Rename target + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + updated, err := fc.Rename(ctx, target.Model, newName) + if err != nil { + _ = inventory.Rollback(tx) + if ent.IsConstraintError(err) { + return nil, fs.ErrFileExisted.WithError(err) + } + + return nil, serializer.NewError(serializer.CodeDBError, "failed to update file", err) + } + + if target.Type() == types.FileTypeFile && !strings.EqualFold(filepath.Ext(newName), filepath.Ext(oldName)) { + if err := fc.RemoveMetadata(ctx, target.Model, ThumbDisabledKey); err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "failed to remove disabled thumbnail mark", err) + } + } + + if err := inventory.Commit(tx); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit rename change", err) + } + + return target.Replace(updated), nil +} + +func (f *DBFS) SoftDelete(ctx context.Context, path ...*fs.URI) error { + ae := serializer.NewAggregateError() + targets := make([]*File, 0, len(path)) + for _, p := range path { + // Get navigator + navigator, err := f.getNavigator(ctx, p, NavigatorCapabilitySoftDelete) + if err != nil { + ae.Add(p.String(), err) + continue + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, p) + if err != nil { + ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err)) + continue + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID { + ae.Add(p.String(), fs.ErrOwnerOnly.WithError(fmt.Errorf("only file owner can delete file without trash bin"))) + continue + } + + // Root folder cannot be deleted + if target.IsRootFolder() { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete root folder"))) + continue + } + + targets = append(targets, target) + } + + if len(targets) == 0 { + return ae.Aggregate() + } + // Lock all targets + lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath { + return &LockByPath{value.Uri(true), value, value.Type(), ""} + }) + ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationSoftDelete), lockTargets...) 
+ defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return err + } + + // Start transaction to soft-delete files + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + for _, target := range targets { + // Perform soft-delete + if err := fc.SoftDelete(ctx, target.Model); err != nil { + _ = inventory.Rollback(tx) + return serializer.NewError(serializer.CodeDBError, "failed to soft-delete file", err) + } + + // Save restore uri into metadata + if err := fc.UpsertMetadata(ctx, target.Model, map[string]string{ + MetadataRestoreUri: target.Uri(true).String(), + MetadataExpectedCollectTime: strconv.FormatInt( + time.Now().Add(time.Duration(target.Owner().Edges.Group.Settings.TrashRetention)*time.Second).Unix(), + 10), + }, nil); err != nil { + _ = inventory.Rollback(tx) + return serializer.NewError(serializer.CodeDBError, "failed to update metadata", err) + } + } + + // Commit transaction + if err := inventory.Commit(tx); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to commit soft-delete change", err) + } + + return ae.Aggregate() +} + +func (f *DBFS) Delete(ctx context.Context, path []*fs.URI, opts ...fs.Option) ([]fs.Entity, error) { + o := newDbfsOption() + for _, opt := range opts { + o.apply(opt) + } + + var opt *types.EntityRecycleOption + if o.UnlinkOnly { + opt = &types.EntityRecycleOption{ + UnlinkOnly: true, + } + } + + ae := serializer.NewAggregateError() + fileNavGroup := make(map[Navigator][]*File) + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + + for _, p := range path { + // Get navigator + navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityDeleteFile, NavigatorCapabilityLockFile) + if err != nil { + ae.Add(p.String(), err) + continue + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, p) + if err != nil { + ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err)) + continue + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !o.SysSkipSoftDelete && !ok && target.Owner().ID != f.user.ID { + ae.Add(p.String(), fs.ErrOwnerOnly) + continue + } + + // Root folder cannot be deleted + if target.IsRootFolder() { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete root folder"))) + continue + } + + if _, ok := fileNavGroup[navigator]; !ok { + fileNavGroup[navigator] = make([]*File, 0) + } + fileNavGroup[navigator] = append(fileNavGroup[navigator], target) + } + + targets := lo.Flatten(lo.Values(fileNavGroup)) + if len(targets) == 0 { + return nil, ae.Aggregate() + } + // Lock all targets + lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath { + return &LockByPath{value.Uri(true), value, value.Type(), ""} + }) + ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationDelete), lockTargets...) 
+ defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return nil, err + } + + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + // Delete targets + newStaleEntities, storageDiff, err := f.deleteFiles(ctx, fileNavGroup, fc, opt) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "failed to delete files", err) + } + + tx.AppendStorageDiff(storageDiff) + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit delete change", err) + } + + return newStaleEntities, ae.Aggregate() +} + +func (f *DBFS) VersionControl(ctx context.Context, path *fs.URI, versionId int, delete bool) error { + // Get navigator + navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityVersionControl) + if err != nil { + return err + } + + // Get target file + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + target, err := f.getFileByPath(ctx, navigator, path) + if err != nil { + return fmt.Errorf("failed to get target file: %w", err) + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID { + return fs.ErrOwnerOnly + } + + // Target must be a file + if target.Type() != types.FileTypeFile { + return fs.ErrNotSupportedAction.WithError(fmt.Errorf("target must be a valid file")) + } + + // Lock file + ls, err := f.acquireByPath(ctx, -1, f.user, true, fs.LockApp(fs.ApplicationVersionControl), + &LockByPath{target.Uri(true), target, target.Type(), ""}) + defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return err + } + + if delete { + storageDiff, err := f.deleteEntity(ctx, target, versionId) + if err != nil { + return err + } + + if err := f.userClient.ApplyStorageDiff(ctx, storageDiff); err != nil { + f.l.Error("Failed to apply storage diff after deleting version: %s", err) + } + return nil + } else { + return f.setCurrentVersion(ctx, target, versionId) + } +} + +func (f *DBFS) Restore(ctx context.Context, path ...*fs.URI) error { + ae := serializer.NewAggregateError() + targets := make([]*File, 0, len(path)) + ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true) + + for _, p := range path { + // Get navigator + navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityRestore) + if err != nil { + ae.Add(p.String(), err) + continue + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, p) + if err != nil { + ae.Add(p.String(), fmt.Errorf("failed to get file: %w", err)) + continue + } + + targets = append(targets, target) + } + + if len(targets) == 0 { + return ae.Aggregate() + } + + allTrashUriStr := lo.FilterMap(targets, func(t *File, key int) ([]*fs.URI, bool) { + if restoreUri, ok := t.Metadata()[MetadataRestoreUri]; ok { + srcUrl, err := fs.NewUriFromString(restoreUri) + if err != nil { + ae.Add(t.Uri(false).String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("invalid restore uri: %w", err))) + return nil, false + } + + return []*fs.URI{t.Uri(false), srcUrl.DirUri()}, true + } + + ae.Add(t.Uri(false).String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot restore file without required metadata mark"))) + return nil, false + }) + + // Copy each file to its original location + for _, uris := range allTrashUriStr { + if err := f.MoveOrCopy(ctx, []*fs.URI{uris[0]}, uris[1], false); err != nil { + if 
!ae.Merge(err) { + ae.Add(uris[0].String(), err) + } + } + } + + return ae.Aggregate() + +} + +func (f *DBFS) MoveOrCopy(ctx context.Context, path []*fs.URI, dst *fs.URI, isCopy bool) error { + targets := make([]*File, 0, len(path)) + dstNavigator, err := f.getNavigator(ctx, dst, NavigatorCapabilityLockFile) + if err != nil { + return err + } + + // Get destination file + destination, err := f.getFileByPath(ctx, dstNavigator, dst) + if err != nil { + return fmt.Errorf("faield to get destination folder: %w", err) + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && destination.Owner().ID != f.user.ID { + return fs.ErrOwnerOnly + } + + // Target must be a folder + if !destination.CanHaveChildren() { + return fs.ErrNotSupportedAction.WithError(fmt.Errorf("destination must be a valid folder")) + } + + ae := serializer.NewAggregateError() + fileNavGroup := make(map[Navigator][]*File) + dstRootPath := destination.Uri(true) + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileMetadata{}, true) + + for _, p := range path { + // Get navigator + navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityLockFile) + if err != nil { + ae.Add(p.String(), err) + continue + } + + // Check fs capability + if !canMoveOrCopyTo(p, dst, isCopy) { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move or copy file form %s to %s", p.String(), dst.String()))) + continue + } + + // Get target file + target, err := f.getFileByPath(ctx, navigator, p) + if err != nil { + ae.Add(p.String(), fmt.Errorf("failed to get file: %w", err)) + continue + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID { + ae.Add(p.String(), fs.ErrOwnerOnly) + continue + } + + // Root folder cannot be moved or copied + if target.IsRootFolder() { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move root folder"))) + continue + } + + // Cannot move or copy folder to its descendant + if target.Type() == types.FileTypeFolder && + dstRootPath.EqualOrIsDescendantOf(target.Uri(true), hashid.EncodeUserID(f.hasher, f.user.ID)) { + ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move or copy folder to itself or its descendant"))) + continue + } + + targets = append(targets, target) + if isCopy { + if _, ok := fileNavGroup[navigator]; !ok { + fileNavGroup[navigator] = make([]*File, 0) + } + fileNavGroup[navigator] = append(fileNavGroup[navigator], target) + } + } + + if len(targets) > 0 { + // Lock all targets + lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath { + return &LockByPath{value.Uri(true), value, value.Type(), ""} + }) + + // Lock destination + dstBase := destination.Uri(true) + dstLockTargets := lo.Map(targets, func(value *File, key int) *LockByPath { + return &LockByPath{dstBase.Join(value.Name()), destination, value.Type(), ""} + }) + allLockTargets := make([]*LockByPath, 0, len(targets)*2) + if !isCopy { + // For moving files from trash bin, also lock the dst with restored name. + dstRestoreTargets := lo.FilterMap(targets, func(value *File, key int) (*LockByPath, bool) { + if _, ok := value.Metadata()[MetadataRestoreUri]; ok { + return &LockByPath{dstBase.Join(value.DisplayName()), destination, value.Type(), ""}, true + } + return nil, false + }) + allLockTargets = append(allLockTargets, lockTargets...) + allLockTargets = append(allLockTargets, dstRestoreTargets...) 
+ } + allLockTargets = append(allLockTargets, dstLockTargets...) + ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationMoveCopy), allLockTargets...) + defer func() { _ = f.Release(ctx, ls) }() + if err != nil { + return err + } + + // Start transaction to move files + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + var ( + storageDiff inventory.StorageDiff + ) + if isCopy { + _, storageDiff, err = f.copyFiles(ctx, fileNavGroup, destination, fc) + } else { + storageDiff, err = f.moveFiles(ctx, targets, destination, fc, dstNavigator) + } + + if err != nil { + _ = inventory.Rollback(tx) + return err + } + + tx.AppendStorageDiff(storageDiff) + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to commit move change", err) + } + + // TODO: after move, dbfs cache should be cleared + } + + return ae.Aggregate() +} + +func (f *DBFS) deleteEntity(ctx context.Context, target *File, entityId int) (inventory.StorageDiff, error) { + if target.PrimaryEntityID() == entityId { + return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete current version")) + } + + targetVersion, found := lo.Find(target.Entities(), func(item fs.Entity) bool { + return item.ID() == entityId + }) + if !found { + return nil, fs.ErrEntityNotExist.WithError(fmt.Errorf("version not found")) + } + + diff, err := f.fileClient.UnlinkEntity(ctx, targetVersion.Model(), target.Model, target.Owner()) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to unlink entity", err) + } + + if targetVersion.UploadSessionID() != nil { + err = f.fileClient.RemoveMetadata(ctx, target.Model, MetadataUploadSessionID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to remove upload session metadata", err) + } + } + return diff, nil +} + +func (f *DBFS) setCurrentVersion(ctx context.Context, target *File, versionId int) error { + if target.PrimaryEntityID() == versionId { + return nil + } + + targetVersion, found := lo.Find(target.Entities(), func(item fs.Entity) bool { + return item.ID() == versionId && item.Type() == types.EntityTypeVersion && item.UploadSessionID() == nil + }) + if !found { + return fs.ErrEntityNotExist.WithError(fmt.Errorf("version not found")) + } + + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + if err := f.fileClient.SetPrimaryEntity(ctx, target.Model, targetVersion.ID()); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to set primary entity", err) + } + + // Cap thumbnail entities + diff, err := fc.CapEntities(ctx, target.Model, target.Owner(), 0, types.EntityTypeThumbnail) + if err != nil { + _ = inventory.Rollback(tx) + return serializer.NewError(serializer.CodeDBError, "Failed to cap thumbnail entities", err) + } + + tx.AppendStorageDiff(diff) + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to commit set current version", err) + } + + return nil +} + +func (f *DBFS) deleteFiles(ctx context.Context, targets map[Navigator][]*File, fc inventory.FileClient, opt *types.EntityRecycleOption) ([]fs.Entity, inventory.StorageDiff, error) { + if f.user.Edges.Group == nil { 
+ return nil, nil, fmt.Errorf("user group not loaded")
+ }
+ limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)
+ allStaleEntities := make([]fs.Entity, 0, len(targets))
+ storageDiff := make(inventory.StorageDiff)
+ for n, files := range targets {
+ // Let navigator use tx
+ reset, err := n.FollowTx(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ defer reset()
+
+ // List all files to be deleted
+ toBeDeletedFiles := make([]*File, 0, len(files))
+ if err := n.Walk(ctx, files, limit, intsets.MaxInt, func(targets []*File, level int) error {
+ limit -= len(targets)
+ toBeDeletedFiles = append(toBeDeletedFiles, targets...)
+ return nil
+ }); err != nil {
+ return nil, nil, fmt.Errorf("failed to walk files: %w", err)
+ }
+
+ // Delete files
+ staleEntities, diff, err := fc.Delete(ctx, lo.Map(toBeDeletedFiles, func(item *File, index int) *ent.File {
+ return item.Model
+ }), opt)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to delete files: %w", err)
+ }
+ storageDiff.Merge(diff)
+ allStaleEntities = append(allStaleEntities, lo.Map(staleEntities, func(item *ent.Entity, index int) fs.Entity {
+ return fs.NewEntity(item)
+ })...)
+ }
+
+ return allStaleEntities, storageDiff, nil
+}
+
+func (f *DBFS) copyFiles(ctx context.Context, targets map[Navigator][]*File, destination *File, fc inventory.FileClient) (map[int]*ent.File, inventory.StorageDiff, error) {
+ if f.user.Edges.Group == nil {
+ return nil, nil, fmt.Errorf("user group not loaded")
+ }
+ limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)
+ capacity, err := f.Capacity(ctx, destination.Owner())
+ if err != nil {
+ return nil, nil, fmt.Errorf("copy files: failed to get destination owner capacity: %w", err)
+ }
+
+ dstAncestors := lo.Map(destination.AncestorsChain(), func(item *File, index int) *ent.File {
+ return item.Model
+ })
+
+ // newTargetsMap maps each source file ID in the first walked layer to its newly created copy.
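+ // Illustrative shape only (the IDs below are made-up values, not taken from this change):
+ //
+ //	newTargetsMap[42] -> *ent.File for the copy of source file 42 under destination
+ //	newTargetsMap[43] -> *ent.File for the copy of source folder 43 under destination
+ //
+ // Only the first walked level is recorded here; deeper descendants are still
+ // copied by fc.Copy below, they are just not returned to the caller.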
+ newTargetsMap := make(map[int]*ent.File) + storageDiff := make(inventory.StorageDiff) + var diff inventory.StorageDiff + for n, files := range targets { + initialDstMap := make(map[int][]*ent.File) + for _, file := range files { + initialDstMap[file.Model.FileChildren] = dstAncestors + } + + firstLayer := true + // Let navigator use tx + reset, err := n.FollowTx(ctx) + if err != nil { + return nil, nil, err + } + + defer reset() + + if err := n.Walk(ctx, files, limit, intsets.MaxInt, func(targets []*File, level int) error { + // check capacity for each file + sizeTotal := int64(0) + for _, file := range targets { + sizeTotal += file.SizeUsed() + } + + if err := f.validateUserCapacityRaw(ctx, sizeTotal, capacity); err != nil { + return fs.ErrInsufficientCapacity + } + + limit -= len(targets) + initialDstMap, diff, err = fc.Copy(ctx, lo.Map(targets, func(item *File, index int) *ent.File { + return item.Model + }), initialDstMap) + if err != nil { + if ent.IsConstraintError(err) { + return fs.ErrFileExisted.WithError(err) + } + + return serializer.NewError(serializer.CodeDBError, "Failed to copy files", err) + } + + storageDiff.Merge(diff) + + if firstLayer { + for k, v := range initialDstMap { + newTargetsMap[k] = v[0] + } + } + + capacity.Used += sizeTotal + firstLayer = false + + return nil + }); err != nil { + return nil, nil, fmt.Errorf("failed to walk files: %w", err) + } + } + + return newTargetsMap, storageDiff, nil +} + +func (f *DBFS) moveFiles(ctx context.Context, targets []*File, destination *File, fc inventory.FileClient, n Navigator) (inventory.StorageDiff, error) { + models := lo.Map(targets, func(value *File, key int) *ent.File { + return value.Model + }) + + // Change targets' parent + if err := fc.SetParent(ctx, models, destination.Model); err != nil { + if ent.IsConstraintError(err) { + return nil, fs.ErrFileExisted.WithError(err) + } + + return nil, serializer.NewError(serializer.CodeDBError, "Failed to move file", err) + } + + var ( + storageDiff inventory.StorageDiff + ) + + // For files moved out from trash bin + for _, file := range targets { + if _, ok := file.Metadata()[MetadataRestoreUri]; !ok { + continue + } + + // renaming it to its original name + if _, err := fc.Rename(ctx, file.Model, file.DisplayName()); err != nil { + if ent.IsConstraintError(err) { + return nil, fs.ErrFileExisted.WithError(err) + } + + return storageDiff, serializer.NewError(serializer.CodeDBError, "Failed to rename file from trash bin to its original name", err) + } + + // Remove trash bin metadata + if err := fc.RemoveMetadata(ctx, file.Model, MetadataRestoreUri, MetadataExpectedCollectTime); err != nil { + return storageDiff, serializer.NewError(serializer.CodeDBError, "Failed to remove trash related metadata", err) + } + } + + return storageDiff, nil +} diff --git a/pkg/filemanager/fs/dbfs/my_navigator.go b/pkg/filemanager/fs/dbfs/my_navigator.go new file mode 100644 index 00000000..8b13d968 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/my_navigator.go @@ -0,0 +1,172 @@ +package dbfs + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var myNavigatorCapability = &boolset.BooleanSet{} + +// NewMyNavigator creates a 
navigator for user's "my" file system. +func NewMyNavigator(u *ent.User, fileClient inventory.FileClient, userClient inventory.UserClient, l logging.Logger, + config *setting.DBFS, hasher hashid.Encoder) Navigator { + return &myNavigator{ + user: u, + l: l, + fileClient: fileClient, + userClient: userClient, + config: config, + baseNavigator: newBaseNavigator(fileClient, defaultFilter, u, hasher, config), + } +} + +type myNavigator struct { + l logging.Logger + user *ent.User + fileClient inventory.FileClient + userClient inventory.UserClient + + config *setting.DBFS + *baseNavigator + root *File + disableRecycle bool + persist func() +} + +func (n *myNavigator) Recycle() { + if n.persist != nil { + n.persist() + n.persist = nil + } + if n.root != nil && !n.disableRecycle { + n.root.Recycle() + } +} + +func (n *myNavigator) PersistState(kv cache.Driver, key string) { + n.disableRecycle = true + n.persist = func() { + kv.Set(key, n.root, ContextHintTTL) + } +} + +func (n *myNavigator) RestoreState(s State) error { + n.disableRecycle = true + if state, ok := s.(*File); ok { + n.root = state + return nil + } + + return fmt.Errorf("invalid state type: %T", s) +} + +func (n *myNavigator) To(ctx context.Context, path *fs.URI) (*File, error) { + if n.root == nil { + // Anonymous user does not have a root folder. + if inventory.IsAnonymousUser(n.user) { + return nil, ErrLoginRequired + } + + fsUid, err := n.hasher.Decode(path.ID(hashid.EncodeUserID(n.hasher, n.user.ID)), hashid.UserID) + if err != nil { + return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("invalid user id")) + } + if fsUid != n.user.ID { + return nil, ErrPermissionDenied + } + + targetUser, err := n.userClient.GetLoginUserByID(ctx, fsUid) + if err != nil { + return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("user not found: %w", err)) + } + + rootFile, err := n.fileClient.Root(ctx, targetUser) + if err != nil { + n.l.Info("User's root folder not found: %s, will initialize it.", err) + return nil, ErrFsNotInitialized + } + + n.root = newFile(nil, rootFile) + rootPath := path.Root() + n.root.Path[pathIndexRoot], n.root.Path[pathIndexUser] = rootPath, rootPath + n.root.OwnerModel = targetUser + n.root.IsUserRoot = true + n.root.CapabilitiesBs = n.Capabilities(false).Capability + } + + current, lastAncestor := n.root, n.root + elements := path.Elements() + var err error + for index, element := range elements { + lastAncestor = current + current, err = n.walkNext(ctx, current, element, index == len(elements)-1) + if err != nil { + return lastAncestor, fmt.Errorf("failed to walk into %q: %w", element, err) + } + } + + return current, nil +} + +func (n *myNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + return n.baseNavigator.children(ctx, parent, args) +} + +func (n *myNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) { + return n.baseNavigator.walkNext(ctx, root, next, isLeaf) +} + +func (n *myNavigator) Capabilities(isSearching bool) *fs.NavigatorProps { + res := &fs.NavigatorProps{ + Capability: myNavigatorCapability, + OrderDirectionOptions: fullOrderDirectionOption, + OrderByOptions: fullOrderByOption, + MaxPageSize: n.config.MaxPageSize, + } + if isSearching { + res.OrderByOptions = nil + res.OrderDirectionOptions = nil + } + + return res +} + +func (n *myNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error { + return n.baseNavigator.walk(ctx, levelFiles, limit, depth, f) +} + +func (n *myNavigator) 
FollowTx(ctx context.Context) (func(), error) { + if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok { + return nil, fmt.Errorf("navigator: no inherited transaction found in context") + } + newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient) + if err != nil { + return nil, err + } + + newUserClient, _, _, err := inventory.WithTx(ctx, n.userClient) + + oldFileClient, oldUserClient := n.fileClient, n.userClient + revert := func() { + n.fileClient = oldFileClient + n.userClient = oldUserClient + n.baseNavigator.fileClient = oldFileClient + } + + n.fileClient = newFileClient + n.userClient = newUserClient + n.baseNavigator.fileClient = newFileClient + return revert, nil +} + +func (n *myNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error { + return nil +} diff --git a/pkg/filemanager/fs/dbfs/navigator.go b/pkg/filemanager/fs/dbfs/navigator.go new file mode 100644 index 00000000..b3dcf50f --- /dev/null +++ b/pkg/filemanager/fs/dbfs/navigator.go @@ -0,0 +1,536 @@ +package dbfs + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" +) + +var ( + ErrFsNotInitialized = fmt.Errorf("fs not initialized") + ErrPermissionDenied = serializer.NewError(serializer.CodeNoPermissionErr, "Permission denied", nil) + + ErrShareIncorrectPassword = serializer.NewError(serializer.CodeIncorrectPassword, "Incorrect share password", nil) + ErrFileCountLimitedReached = serializer.NewError(serializer.CodeFileCountLimitedReached, "Walked file count reached limit", nil) + ErrSymbolicFolderFound = serializer.NewError(serializer.CodeNoPermissionErr, "Symbolic folder cannot be walked into", nil) + ErrLoginRequired = serializer.NewError(serializer.CodeCheckLogin, "Login required", nil) + + fullOrderByOption = []string{"name", "size", "updated_at", "created_at"} + searchLimitedOrderByOption = []string{"created_at"} + fullOrderDirectionOption = []string{"asc", "desc"} +) + +type ( + // Navigator is a navigator for database file system. + Navigator interface { + Recycle() + // To returns the file by path. If given path is not exist, returns ErrFileNotFound and most-recent ancestor. + To(ctx context.Context, path *fs.URI) (*File, error) + // Children returns the children of the parent file. + Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) + // Capabilities returns the capabilities of the navigator. + Capabilities(isSearching bool) *fs.NavigatorProps + // Walk walks the file tree until limit is reached. + Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error + // PersistState tells navigator to persist the state of the navigator before recycle. + PersistState(kv cache.Driver, key string) + // RestoreState restores the state of the navigator. + RestoreState(s State) error + // FollowTx let the navigator inherit the transaction. Return a function to reset back to previous DB client. 
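+ // Call sites in this change follow the pattern below so the navigator
+ // temporarily reads through the transaction-bound clients (pattern taken
+ // from deleteFiles/copyFiles; error handling abbreviated):
+ //
+ //	reset, err := n.FollowTx(ctx)
+ //	if err != nil {
+ //		return err
+ //	}
+ //	defer reset()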
+ FollowTx(ctx context.Context) (func(), error) + // ExecuteHook performs custom operations before or after certain actions. + ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error + } + + State interface{} + + NavigatorCapability int + ListArgs struct { + Page *inventory.PaginationArgs + Search *inventory.SearchFileParameters + SharedWithMe bool + StreamCallback func([]*File) + } + // ListResult is the result of a list operation. + ListResult struct { + Files []*File + MixedType bool + Pagination *inventory.PaginationResults + RecursionLimitReached bool + SingleFileView bool + } + WalkFunc func([]*File, int) error +) + +const ( + NavigatorCapabilityCreateFile NavigatorCapability = iota + NavigatorCapabilityRenameFile + NavigatorCapability_CommunityPlacehodler1 + NavigatorCapability_CommunityPlacehodler2 + NavigatorCapability_CommunityPlacehodler3 + NavigatorCapability_CommunityPlacehodler4 + NavigatorCapabilityUploadFile + NavigatorCapabilityDownloadFile + NavigatorCapabilityUpdateMetadata + NavigatorCapabilityListChildren + NavigatorCapabilityGenerateThumb + NavigatorCapability_CommunityPlacehodler5 + NavigatorCapability_CommunityPlacehodler6 + NavigatorCapability_CommunityPlacehodler7 + NavigatorCapabilityDeleteFile + NavigatorCapabilityLockFile + NavigatorCapabilitySoftDelete + NavigatorCapabilityRestore + NavigatorCapabilityShare + NavigatorCapabilityInfo + NavigatorCapabilityVersionControl + NavigatorCapability_CommunityPlacehodler8 + NavigatorCapability_CommunityPlacehodler9 + NavigatorCapabilityEnterFolder + + searchTokenSeparator = "|" +) + +func init() { + boolset.Sets(map[NavigatorCapability]bool{ + NavigatorCapabilityCreateFile: true, + NavigatorCapabilityRenameFile: true, + NavigatorCapabilityUploadFile: true, + NavigatorCapabilityDownloadFile: true, + NavigatorCapabilityUpdateMetadata: true, + NavigatorCapabilityListChildren: true, + NavigatorCapabilityGenerateThumb: true, + NavigatorCapabilityDeleteFile: true, + NavigatorCapabilityLockFile: true, + NavigatorCapabilitySoftDelete: true, + NavigatorCapabilityShare: true, + NavigatorCapabilityInfo: true, + NavigatorCapabilityVersionControl: true, + NavigatorCapabilityEnterFolder: true, + }, myNavigatorCapability) + boolset.Sets(map[NavigatorCapability]bool{ + NavigatorCapabilityDownloadFile: true, + NavigatorCapabilityListChildren: true, + NavigatorCapabilityGenerateThumb: true, + NavigatorCapabilityLockFile: true, + NavigatorCapabilityInfo: true, + NavigatorCapabilityVersionControl: true, + NavigatorCapabilityEnterFolder: true, + }, shareNavigatorCapability) + boolset.Sets(map[NavigatorCapability]bool{ + NavigatorCapabilityListChildren: true, + NavigatorCapabilityDeleteFile: true, + NavigatorCapabilityLockFile: true, + NavigatorCapabilityRestore: true, + NavigatorCapabilityInfo: true, + }, trashNavigatorCapability) + boolset.Sets(map[NavigatorCapability]bool{ + NavigatorCapabilityListChildren: true, + NavigatorCapabilityDownloadFile: true, + NavigatorCapabilityEnterFolder: true, + }, sharedWithMeNavigatorCapability) +} + +// ==================== Base Navigator ==================== +type ( + fileFilter func(ctx context.Context, f *File) (*File, bool) + baseNavigator struct { + fileClient inventory.FileClient + listFilter fileFilter + user *ent.User + hasher hashid.Encoder + config *setting.DBFS + } +) + +var defaultFilter = func(ctx context.Context, f *File) (*File, bool) { return f, true } + +func newBaseNavigator(fileClient inventory.FileClient, filterFunc fileFilter, user *ent.User, + hasher hashid.Encoder, 
config *setting.DBFS) *baseNavigator { + return &baseNavigator{ + fileClient: fileClient, + listFilter: filterFunc, + user: user, + hasher: hasher, + config: config, + } +} + +func (b *baseNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) { + var model *ent.File + if root != nil { + model = root.Model + if root.IsSymbolic() { + return nil, ErrSymbolicFolderFound + } + + root.mu.Lock() + if child, ok := root.Children[next]; ok && !isLeaf { + root.mu.Unlock() + return child, nil + } + root.mu.Unlock() + } + + child, err := b.fileClient.GetChildFile(ctx, model, b.user.ID, next, isLeaf) + if err != nil { + if ent.IsNotFound(err) { + return nil, fs.ErrPathNotExist.WithError(err) + } + + return nil, fmt.Errorf("faield to get child %q: %w", next, err) + } + + return newFile(root, child), nil +} + +func (b *baseNavigator) walkUp(ctx context.Context, child *File) (*File, error) { + parent, err := b.fileClient.GetParentFile(ctx, child.Model, false) + if err != nil { + return nil, fmt.Errorf("faield to get Parent for %q: %w", child.Name(), err) + } + + return newParentFile(parent, child), nil +} + +func (b *baseNavigator) children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + var model *ent.File + if parent != nil { + model = parent.Model + if parent.Model.Type != int(types.FileTypeFolder) { + return nil, fs.ErrPathNotExist + } + + if parent.IsSymbolic() { + return nil, ErrSymbolicFolderFound + } + + parent.Path[pathIndexUser] = parent.Uri(false) + } + + if args.Search != nil { + return b.search(ctx, parent, args) + } + + children, err := b.fileClient.GetChildFiles(ctx, &inventory.ListFileParameters{ + PaginationArgs: args.Page, + SharedWithMe: args.SharedWithMe, + }, b.user.ID, model) + if err != nil { + return nil, fmt.Errorf("failed to get children: %w", err) + } + + return &ListResult{ + Files: lo.FilterMap(children.Files, func(model *ent.File, index int) (*File, bool) { + f := newFile(parent, model) + return b.listFilter(ctx, f) + }), + MixedType: children.MixedType, + Pagination: children.PaginationResults, + }, nil +} + +func (b *baseNavigator) walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error { + walked := 0 + if len(levelFiles) == 0 { + return nil + } + + owner := levelFiles[0].Owner() + + level := 0 + for walked <= limit && depth >= 0 { + if len(levelFiles) == 0 { + break + } + + stop := false + depth-- + if len(levelFiles) > limit-walked { + levelFiles = levelFiles[:limit-walked] + stop = true + } + if err := f(levelFiles, level); err != nil { + return err + } + + if stop { + return ErrFileCountLimitedReached + } + + walked += len(levelFiles) + folders := lo.Filter(levelFiles, func(f *File, index int) bool { + return f.Model.Type == int(types.FileTypeFolder) && !f.IsSymbolic() + }) + + if walked >= limit || len(folders) == 0 { + break + } + + levelFiles = levelFiles[:0] + leftCredit := limit - walked + parents := lo.SliceToMap(folders, func(file *File) (int, *File) { + return file.Model.ID, file + }) + for leftCredit > 0 { + token := "" + res, err := b.fileClient.GetChildFiles(ctx, + &inventory.ListFileParameters{ + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageToken: token, + PageSize: leftCredit, + }, + MixedType: true, + }, + owner.ID, + lo.Map(folders, func(item *File, index int) *ent.File { + return item.Model + })...) 
+ if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to list children", err) + } + + leftCredit -= len(res.Files) + + levelFiles = append(levelFiles, lo.Map(res.Files, func(model *ent.File, index int) *File { + p := parents[model.FileChildren] + return newFile(p, model) + })...) + + // All files listed + if res.NextPageToken == "" { + break + } + + token = res.NextPageToken + } + level++ + } + + if walked >= limit { + return ErrFileCountLimitedReached + } + + return nil +} + +func (b *baseNavigator) search(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + if parent == nil { + // Performs mega search for all files in trash fs. + children, err := b.fileClient.GetChildFiles(ctx, &inventory.ListFileParameters{ + PaginationArgs: args.Page, + MixedType: true, + Search: args.Search, + SharedWithMe: args.SharedWithMe, + }, b.user.ID, nil) + if err != nil { + return nil, fmt.Errorf("failed to get children: %w", err) + } + + return &ListResult{ + Files: lo.FilterMap(children.Files, func(model *ent.File, index int) (*File, bool) { + f := newFile(parent, model) + return b.listFilter(ctx, f) + }), + MixedType: children.MixedType, + Pagination: children.PaginationResults, + }, nil + } + // Performs recursive search for all files under the given folder. + walkedFolder := 1 + parents := []map[int]*File{{parent.Model.ID: parent}} + startLevel, innerPageToken, err := parseSearchPageToken(args.Page.PageToken) + if err != nil { + return nil, err + } + args.Page.PageToken = innerPageToken + + stepLevel := func(level int) (bool, error) { + token := "" + // We don't need metadata in level search. + listCtx := context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, nil) + for walkedFolder <= b.config.MaxRecursiveSearchedFolder { + // TODO: chunk parents into 30000 per group + res, err := b.fileClient.GetChildFiles(listCtx, + &inventory.ListFileParameters{ + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageToken: token, + }, + FolderOnly: true, + }, + parent.Model.OwnerID, + lo.MapToSlice(parents[level], func(k int, f *File) *ent.File { + return f.Model + })...) + if err != nil { + return false, serializer.NewError(serializer.CodeDBError, "Failed to list children", err) + } + + parents = append(parents, lo.SliceToMap( + lo.FilterMap(res.Files, func(model *ent.File, index int) (*File, bool) { + p := parents[level][model.FileChildren] + f := newFile(p, model) + f.Path[pathIndexUser] = p.Uri(false).Join(model.Name) + return f, true + }), + func(f *File) (int, *File) { + return f.Model.ID, f + })) + + walkedFolder += len(parents[level+1]) + if res.NextPageToken == "" { + break + } + + token = res.NextPageToken + } + + if len(parents) <= level+1 || len(parents[level+1]) == 0 { + // All possible folders is searched + return true, nil + } + + return false, nil + } + + // We need to walk from root folder to get the correct level. 
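+ // The composite page token produced at the end of this function is
+ // "<level>" + searchTokenSeparator + "<inner cursor>", e.g. "2|abc" (values
+ // illustrative): replay the folder expansion down to level 2, then resume the
+ // inner cursor "abc". parseSearchPageToken splits it back apart.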
+ for level := 0; level < startLevel; level++ { + stop, err := stepLevel(level) + if err != nil { + return nil, err + } + + if stop { + return &ListResult{}, nil + } + } + + // Search files starting from current level + res := make([]*File, 0, args.Page.PageSize) + args.Page.UseCursorPagination = true + originalPageSize := args.Page.PageSize + stop := false + for len(res) < originalPageSize && walkedFolder <= b.config.MaxRecursiveSearchedFolder { + // Only requires minimum number of files + args.Page.PageSize = min(originalPageSize, originalPageSize-len(res)) + searchRes, err := b.fileClient.GetChildFiles(ctx, + &inventory.ListFileParameters{ + PaginationArgs: args.Page, + MixedType: true, + Search: args.Search, + }, + parent.Model.OwnerID, + lo.MapToSlice(parents[startLevel], func(k int, f *File) *ent.File { + return f.Model + })...) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to search files", err) + } + + newRes := lo.FilterMap(searchRes.Files, func(model *ent.File, index int) (*File, bool) { + p := parents[startLevel][model.FileChildren] + f := newFile(p, model) + f.Path[pathIndexUser] = p.Uri(false).Join(model.Name) + return b.listFilter(ctx, f) + }) + res = append(res, newRes...) + if args.StreamCallback != nil { + args.StreamCallback(newRes) + } + + args.Page.PageToken = searchRes.NextPageToken + // If no more results under current level, move to next level + if args.Page.PageToken == "" { + if len(res) == originalPageSize { + // Current page is full, no need to search more + startLevel++ + break + } + + finished, err := stepLevel(startLevel) + if err != nil { + return nil, err + } + + if finished { + stop = true + // No more folders under next level, all result is presented + break + } + + startLevel++ + } + } + + if args.StreamCallback != nil { + // Clear res if it's streamed + res = res[:0] + } + + searchRes := &ListResult{ + Files: res, + MixedType: true, + Pagination: &inventory.PaginationResults{IsCursor: true}, + RecursionLimitReached: walkedFolder > b.config.MaxRecursiveSearchedFolder, + } + + if walkedFolder <= b.config.MaxRecursiveSearchedFolder && !stop { + searchRes.Pagination.NextPageToken = fmt.Sprintf("%d%s%s", startLevel, searchTokenSeparator, args.Page.PageToken) + } + + return searchRes, nil +} + +func parseSearchPageToken(token string) (int, string, error) { + if token == "" { + return 0, "", nil + } + + tokens := strings.Split(token, searchTokenSeparator) + if len(tokens) != 2 { + return 0, "", fmt.Errorf("invalid page token") + } + + level, err := strconv.Atoi(tokens[0]) + if err != nil || level < 0 { + return 0, "", fmt.Errorf("invalid page token level") + } + + return level, tokens[1], nil +} + +func newMyUri() *fs.URI { + res, _ := fs.NewUriFromString(constants.CloudreveScheme + "://" + string(constants.FileSystemMy)) + return res +} + +func newMyIDUri(uid string) *fs.URI { + res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s@%s", constants.CloudreveScheme, uid, constants.FileSystemMy)) + return res +} + +func newTrashUri(name string) *fs.URI { + res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s", constants.CloudreveScheme, constants.FileSystemTrash)) + return res.Join(name) +} + +func newSharedWithMeUri(id string) *fs.URI { + res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s", constants.CloudreveScheme, constants.FileSystemSharedWithMe)) + return res.Join(id) +} diff --git a/pkg/filemanager/fs/dbfs/options.go b/pkg/filemanager/fs/dbfs/options.go new file mode 100644 index 00000000..e2b8ce70 --- /dev/null +++ 
b/pkg/filemanager/fs/dbfs/options.go @@ -0,0 +1,171 @@ +package dbfs + +import ( + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" +) + +type dbfsOption struct { + *fs.FsOption + loadFolderSummary bool + extendedInfo bool + loadFilePublicMetadata bool + loadFileShareIfOwned bool + loadEntityUser bool + loadFileEntities bool + useCursorPagination bool + pageToken string + preferredStoragePolicy *ent.StoragePolicy + errOnConflict bool + previousVersion string + removeStaleEntities bool + requiredCapabilities []NavigatorCapability + generateContextHint bool + isSymbolicLink bool + noChainedCreation bool + streamListResponseCallback func(parent fs.File, file []fs.File) + ancestor *File +} + +func newDbfsOption() *dbfsOption { + return &dbfsOption{ + FsOption: &fs.FsOption{}, + } +} + +func (o *dbfsOption) apply(opt fs.Option) { + if fsOpt, ok := opt.(fs.OptionFunc); ok { + fsOpt.Apply(o.FsOption) + } else if dbfsOpt, ok := opt.(optionFunc); ok { + dbfsOpt.Apply(o) + } +} + +type optionFunc func(*dbfsOption) + +func (f optionFunc) Apply(o any) { + if dbfsO, ok := o.(*dbfsOption); ok { + f(dbfsO) + } +} + +// WithFilePublicMetadata enables loading file public metadata. +func WithFilePublicMetadata() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.loadFilePublicMetadata = true + }) +} + +// WithContextHint enables generating context hint for the list operation. +func WithContextHint() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.generateContextHint = true + }) +} + +// WithFileEntities enables loading file entities. +func WithFileEntities() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.loadFileEntities = true + }) +} + +// WithCursorPagination enables cursor pagination for the list operation. +func WithCursorPagination(pageToken string) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.useCursorPagination = true + o.pageToken = pageToken + }) +} + +// WithPreferredStoragePolicy sets the preferred storage policy for the upload operation. +func WithPreferredStoragePolicy(policy *ent.StoragePolicy) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.preferredStoragePolicy = policy + }) +} + +// WithErrorOnConflict sets to throw error on conflict for the create operation. +func WithErrorOnConflict() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.errOnConflict = true + }) +} + +// WithPreviousVersion sets the previous version for the update operation. +func WithPreviousVersion(version string) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.previousVersion = version + }) +} + +// WithRemoveStaleEntities sets to remove stale entities for the update operation. +func WithRemoveStaleEntities() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.removeStaleEntities = true + }) +} + +// WithRequiredCapabilities sets the required capabilities for operations. +func WithRequiredCapabilities(capabilities ...NavigatorCapability) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.requiredCapabilities = capabilities + }) +} + +// WithNoChainedCreation sets to disable chained creation for the create operation. This +// will require parent folder existed before creating new files under it. +func WithNoChainedCreation() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.noChainedCreation = true + }) +} + +// WithFileShareIfOwned enables loading file share link if the file is owned by the user. 
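+//
+// Like the other options in this file it composes variadically. A hypothetical
+// call site (the receiving method here is illustrative, not something defined
+// in this change) could look like:
+//
+//	f.Get(ctx, uri, WithFileShareIfOwned(), WithExtendedInfo(), WithLoadFolderSummary())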
+func WithFileShareIfOwned() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.loadFileShareIfOwned = true + }) +} + +// WithStreamListResponseCallback sets the callback for handling stream list response. +func WithStreamListResponseCallback(callback func(parent fs.File, file []fs.File)) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.streamListResponseCallback = callback + }) +} + +// WithSymbolicLink sets the file is a symbolic link. +func WithSymbolicLink() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.isSymbolicLink = true + }) +} + +// WithExtendedInfo enables loading extended info for the file. +func WithExtendedInfo() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.extendedInfo = true + }) +} + +// WithLoadFolderSummary enables loading folder summary. +func WithLoadFolderSummary() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.loadFolderSummary = true + }) +} + +// WithEntityUser enables loading entity user. +func WithEntityUser() fs.Option { + return optionFunc(func(o *dbfsOption) { + o.loadEntityUser = true + }) +} + +// WithAncestor sets most recent ancestor for creating files +func WithAncestor(f *File) fs.Option { + return optionFunc(func(o *dbfsOption) { + o.ancestor = f + }) +} diff --git a/pkg/filemanager/fs/dbfs/share_navigator.go b/pkg/filemanager/fs/dbfs/share_navigator.go new file mode 100644 index 00000000..61bfc3b6 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/share_navigator.go @@ -0,0 +1,324 @@ +package dbfs + +import ( + "context" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var ( + ErrShareNotFound = serializer.NewError(serializer.CodeNotFound, "Shared file does not exist", nil) + ErrNotPurchased = serializer.NewError(serializer.CodePurchaseRequired, "You need to purchased this share", nil) +) + +const ( + PurchaseTicketHeader = constants.CrHeaderPrefix + "Purchase-Ticket" +) + +var shareNavigatorCapability = &boolset.BooleanSet{} + +// NewShareNavigator creates a navigator for user's "shared" file system. 
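+// The share is resolved from the hash ID embedded in the URI (path.ID) and the
+// optional password carried by the URI (path.Password); for single-file shares
+// the shared file's parent becomes the navigator root and only that file is listed.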
+func NewShareNavigator(u *ent.User, fileClient inventory.FileClient, shareClient inventory.ShareClient, + l logging.Logger, config *setting.DBFS, hasher hashid.Encoder) Navigator { + n := &shareNavigator{ + user: u, + l: l, + fileClient: fileClient, + shareClient: shareClient, + config: config, + } + n.baseNavigator = newBaseNavigator(fileClient, defaultFilter, u, hasher, config) + return n +} + +type ( + shareNavigator struct { + l logging.Logger + user *ent.User + fileClient inventory.FileClient + shareClient inventory.ShareClient + config *setting.DBFS + + *baseNavigator + shareRoot *File + singleFileShare bool + ownerRoot *File + share *ent.Share + owner *ent.User + disableRecycle bool + persist func() + } + + shareNavigatorState struct { + ShareRoot *File + OwnerRoot *File + SingleFileShare bool + Share *ent.Share + Owner *ent.User + } +) + +func (n *shareNavigator) PersistState(kv cache.Driver, key string) { + n.disableRecycle = true + n.persist = func() { + kv.Set(key, shareNavigatorState{ + ShareRoot: n.shareRoot, + OwnerRoot: n.ownerRoot, + SingleFileShare: n.singleFileShare, + Share: n.share, + Owner: n.owner, + }, ContextHintTTL) + } +} + +func (n *shareNavigator) RestoreState(s State) error { + n.disableRecycle = true + if state, ok := s.(shareNavigatorState); ok { + n.shareRoot = state.ShareRoot + n.ownerRoot = state.OwnerRoot + n.singleFileShare = state.SingleFileShare + n.share = state.Share + n.owner = state.Owner + return nil + } + + return fmt.Errorf("invalid state type: %T", s) +} + +func (n *shareNavigator) Recycle() { + if n.persist != nil { + n.persist() + n.persist = nil + } + + if !n.disableRecycle { + if n.ownerRoot != nil { + n.ownerRoot.Recycle() + } else if n.shareRoot != nil { + n.shareRoot.Recycle() + } + } +} + +func (n *shareNavigator) Root(ctx context.Context, path *fs.URI) (*File, error) { + ctx = context.WithValue(ctx, inventory.LoadShareUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true) + ctx = context.WithValue(ctx, inventory.LoadShareFile{}, true) + share, err := n.shareClient.GetByHashID(ctx, path.ID(hashid.EncodeUserID(n.hasher, n.user.ID))) + if err != nil { + return nil, ErrShareNotFound.WithError(err) + } + + if err := inventory.IsValidShare(share); err != nil { + return nil, ErrShareNotFound.WithError(err) + } + + n.owner = share.Edges.User + + // Check password + if share.Password != "" && share.Password != path.Password() { + return nil, ErrShareIncorrectPassword + } + + // Share permission setting should overwrite root folder's permission + n.shareRoot = newFile(nil, share.Edges.File) + + // Find the user side root of the file. 
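+ // findRoot walks up the parent chain until no parent remains; if the topmost
+ // ancestor is not the owner's root folder, some ancestor of the shared file
+ // has been deleted and the share is rejected as not found below.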
+ ownerRoot, err := n.findRoot(ctx, n.shareRoot) + if err != nil { + return nil, err + } + + if n.shareRoot.Type() == types.FileTypeFile { + n.singleFileShare = true + n.shareRoot = n.shareRoot.Parent + } + + n.shareRoot.Path[pathIndexUser] = path.Root() + n.shareRoot.OwnerModel = n.owner + n.shareRoot.IsUserRoot = true + n.shareRoot.CapabilitiesBs = n.Capabilities(false).Capability + + // Check if any ancestors is deleted + if ownerRoot.Name() != inventory.RootFolderName { + return nil, ErrShareNotFound + } + + if n.user.ID != n.owner.ID && !n.user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionShareDownload)) { + return nil, serializer.NewError( + serializer.CodeNoPermissionErr, + fmt.Sprintf("You don't have permission to access share links"), + err, + ) + } + + n.ownerRoot = ownerRoot + n.ownerRoot.Path[pathIndexRoot] = newMyIDUri(hashid.EncodeUserID(n.hasher, n.owner.ID)) + n.share = share + return n.shareRoot, nil +} + +func (n *shareNavigator) To(ctx context.Context, path *fs.URI) (*File, error) { + if n.shareRoot == nil { + root, err := n.Root(ctx, path) + if err != nil { + return nil, err + } + + n.shareRoot = root + } + + current, lastAncestor := n.shareRoot, n.shareRoot + elements := path.Elements() + + // If target is root of single file share, the root itself is the target. + if len(elements) <= 1 && n.singleFileShare { + file, err := n.latestSharedSingleFile(ctx) + if err != nil { + return nil, err + } + + if len(elements) == 1 && file.Name() != elements[0] { + return nil, fs.ErrPathNotExist + } + + return file, nil + } + + var err error + for index, element := range elements { + lastAncestor = current + current, err = n.walkNext(ctx, current, element, index == len(elements)-1) + if err != nil { + return lastAncestor, fmt.Errorf("failed to walk into %q: %w", element, err) + } + } + + return current, nil +} + +func (n *shareNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) { + nextFile, err := n.baseNavigator.walkNext(ctx, root, next, isLeaf) + if err != nil { + return nil, err + } + + return nextFile, nil +} + +func (n *shareNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + if n.singleFileShare { + file, err := n.latestSharedSingleFile(ctx) + if err != nil { + return nil, err + } + + return &ListResult{ + Files: []*File{file}, + Pagination: &inventory.PaginationResults{}, + SingleFileView: true, + }, nil + } + + return n.baseNavigator.children(ctx, parent, args) +} + +func (n *shareNavigator) latestSharedSingleFile(ctx context.Context) (*File, error) { + if n.singleFileShare { + file, err := n.fileClient.GetByID(ctx, n.share.Edges.File.ID) + if err != nil { + return nil, err + } + + f := newFile(n.shareRoot, file) + f.OwnerModel = n.shareRoot.OwnerModel + + return f, nil + } + + return nil, fs.ErrPathNotExist +} + +func (n *shareNavigator) Capabilities(isSearching bool) *fs.NavigatorProps { + res := &fs.NavigatorProps{ + Capability: shareNavigatorCapability, + OrderDirectionOptions: fullOrderDirectionOption, + OrderByOptions: fullOrderByOption, + MaxPageSize: n.config.MaxPageSize, + } + + if isSearching { + res.OrderByOptions = nil + res.OrderDirectionOptions = nil + } + + return res +} + +func (n *shareNavigator) FollowTx(ctx context.Context) (func(), error) { + if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok { + return nil, fmt.Errorf("navigator: no inherited transaction found in context") + } + newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient) + if 
err != nil { + return nil, err + } + + newSharClient, _, _, err := inventory.WithTx(ctx, n.shareClient) + + oldFileClient, oldShareClient := n.fileClient, n.shareClient + revert := func() { + n.fileClient = oldFileClient + n.shareClient = oldShareClient + n.baseNavigator.fileClient = oldFileClient + } + + n.fileClient = newFileClient + n.shareClient = newSharClient + n.baseNavigator.fileClient = newFileClient + return revert, nil +} + +func (n *shareNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error { + switch hookType { + case fs.HookTypeBeforeDownload: + if n.singleFileShare { + return n.shareClient.Downloaded(ctx, n.share) + } + } + return nil +} + +// findRoot finds the root folder of the given child. +func (n *shareNavigator) findRoot(ctx context.Context, child *File) (*File, error) { + root := child + for { + newRoot, err := n.baseNavigator.walkUp(ctx, root) + if err != nil { + if !ent.IsNotFound(err) { + return nil, err + } + + break + } + + root = newRoot + } + + return root, nil +} + +func (n *shareNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error { + return n.baseNavigator.walk(ctx, levelFiles, limit, depth, f) +} diff --git a/pkg/filemanager/fs/dbfs/sharewithme_navigator.go b/pkg/filemanager/fs/dbfs/sharewithme_navigator.go new file mode 100644 index 00000000..4d896b77 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/sharewithme_navigator.go @@ -0,0 +1,141 @@ +package dbfs + +import ( + "context" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var sharedWithMeNavigatorCapability = &boolset.BooleanSet{} + +// NewSharedWithMeNavigator creates a navigator for user's "shared with me" file system. +func NewSharedWithMeNavigator(u *ent.User, fileClient inventory.FileClient, l logging.Logger, + config *setting.DBFS, hasher hashid.Encoder) Navigator { + n := &sharedWithMeNavigator{ + user: u, + l: l, + fileClient: fileClient, + config: config, + hasher: hasher, + } + n.baseNavigator = newBaseNavigator(fileClient, defaultFilter, u, hasher, config) + return n +} + +type sharedWithMeNavigator struct { + l logging.Logger + user *ent.User + fileClient inventory.FileClient + config *setting.DBFS + hasher hashid.Encoder + + root *File + *baseNavigator +} + +func (t *sharedWithMeNavigator) Recycle() { + +} + +func (n *sharedWithMeNavigator) PersistState(kv cache.Driver, key string) { +} + +func (n *sharedWithMeNavigator) RestoreState(s State) error { + return nil +} + +func (t *sharedWithMeNavigator) To(ctx context.Context, path *fs.URI) (*File, error) { + // Anonymous user does not have a trash folder. + if inventory.IsAnonymousUser(t.user) { + return nil, ErrLoginRequired + } + + elements := path.Elements() + if len(elements) > 0 { + // Shared with me folder is a flatten tree, only root can be accessed. 
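+ // Entries are listed flat by Children below and addressed by hashed file ID
+ // rather than by nested paths, so any path below the root is rejected here.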
+ return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("invalid Path %q", path)) + } + + if t.root == nil { + rootFile, err := t.fileClient.Root(ctx, t.user) + if err != nil { + t.l.Info("User's root folder not found: %s, will initialize it.", err) + return nil, ErrFsNotInitialized + } + + t.root = newFile(nil, rootFile) + rootPath := newSharedWithMeUri("") + t.root.Path[pathIndexRoot], t.root.Path[pathIndexUser] = rootPath, rootPath + t.root.OwnerModel = t.user + t.root.IsUserRoot = true + t.root.CapabilitiesBs = t.Capabilities(false).Capability + } + + return t.root, nil +} + +func (t *sharedWithMeNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + args.SharedWithMe = true + res, err := t.baseNavigator.children(ctx, nil, args) + if err != nil { + return nil, err + } + + // Adding user uri for each file. + for i := 0; i < len(res.Files); i++ { + res.Files[i].Path[pathIndexUser] = newSharedWithMeUri(hashid.EncodeFileID(t.hasher, res.Files[i].Model.ID)) + } + + return res, nil +} + +func (t *sharedWithMeNavigator) Capabilities(isSearching bool) *fs.NavigatorProps { + res := &fs.NavigatorProps{ + Capability: sharedWithMeNavigatorCapability, + OrderDirectionOptions: fullOrderDirectionOption, + OrderByOptions: fullOrderByOption, + MaxPageSize: t.config.MaxPageSize, + } + + if isSearching { + res.OrderByOptions = searchLimitedOrderByOption + } + + return res +} + +func (t *sharedWithMeNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error { + return errors.New("not implemented") +} + +func (n *sharedWithMeNavigator) FollowTx(ctx context.Context) (func(), error) { + if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok { + return nil, fmt.Errorf("navigator: no inherited transaction found in context") + } + newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient) + if err != nil { + return nil, err + } + + oldFileClient := n.fileClient + revert := func() { + n.fileClient = oldFileClient + n.baseNavigator.fileClient = oldFileClient + } + + n.fileClient = newFileClient + n.baseNavigator.fileClient = newFileClient + return revert, nil +} + +func (n *sharedWithMeNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error { + return nil +} diff --git a/pkg/filemanager/fs/dbfs/trash_navigator.go b/pkg/filemanager/fs/dbfs/trash_navigator.go new file mode 100644 index 00000000..1bf45c66 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/trash_navigator.go @@ -0,0 +1,137 @@ +package dbfs + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var trashNavigatorCapability = &boolset.BooleanSet{} + +// NewTrashNavigator creates a navigator for user's "trash" file system. 
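+// The trash file system is flat: To resolves at most one path element and
+// Children only lists from the virtual trash root.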
+func NewTrashNavigator(u *ent.User, fileClient inventory.FileClient, l logging.Logger, config *setting.DBFS, + hasher hashid.Encoder) Navigator { + return &trashNavigator{ + user: u, + l: l, + fileClient: fileClient, + config: config, + baseNavigator: newBaseNavigator(fileClient, defaultFilter, u, hasher, config), + } +} + +type trashNavigator struct { + l logging.Logger + user *ent.User + fileClient inventory.FileClient + config *setting.DBFS + + *baseNavigator +} + +func (t *trashNavigator) Recycle() { + +} + +func (n *trashNavigator) PersistState(kv cache.Driver, key string) { +} + +func (n *trashNavigator) RestoreState(s State) error { + return nil +} + +func (t *trashNavigator) To(ctx context.Context, path *fs.URI) (*File, error) { + // Anonymous user does not have a trash folder. + if inventory.IsAnonymousUser(t.user) { + return nil, ErrLoginRequired + } + + elements := path.Elements() + if len(elements) > 1 { + // Trash folder is a flatten tree, only 1 layer is supported. + return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("invalid Path %q", path)) + } + + if len(elements) == 0 { + // Trash folder has no root. + return nil, nil + } + + current, err := t.walkNext(ctx, nil, elements[0], true) + if err != nil { + return nil, fmt.Errorf("failed to walk into %q: %w", elements[0], err) + } + + current.Path[pathIndexUser] = newTrashUri(current.Model.Name) + current.Path[pathIndexRoot] = current.Path[pathIndexUser] + current.OwnerModel = t.user + return current, nil +} + +func (t *trashNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) { + if parent != nil { + return nil, fs.ErrPathNotExist + } + + res, err := t.baseNavigator.children(ctx, nil, args) + if err != nil { + return nil, err + } + + // Adding user uri for each file. 
+ for i := 0; i < len(res.Files); i++ { + res.Files[i].Path[pathIndexUser] = newTrashUri(res.Files[i].Model.Name) + } + + return res, nil +} + +func (t *trashNavigator) Capabilities(isSearching bool) *fs.NavigatorProps { + res := &fs.NavigatorProps{ + Capability: trashNavigatorCapability, + OrderDirectionOptions: fullOrderDirectionOption, + OrderByOptions: fullOrderByOption, + MaxPageSize: t.config.MaxPageSize, + } + + if isSearching { + res.OrderByOptions = searchLimitedOrderByOption + } + + return res +} + +func (t *trashNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error { + return t.baseNavigator.walk(ctx, levelFiles, limit, depth, f) +} + +func (n *trashNavigator) FollowTx(ctx context.Context) (func(), error) { + if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok { + return nil, fmt.Errorf("navigator: no inherited transaction found in context") + } + newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient) + if err != nil { + return nil, err + } + + oldFileClient := n.fileClient + revert := func() { + n.fileClient = oldFileClient + n.baseNavigator.fileClient = oldFileClient + } + + n.fileClient = newFileClient + n.baseNavigator.fileClient = newFileClient + return revert, nil +} + +func (n *trashNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error { + return nil +} diff --git a/pkg/filemanager/fs/dbfs/upload.go b/pkg/filemanager/fs/dbfs/upload.go new file mode 100644 index 00000000..8fcbd93d --- /dev/null +++ b/pkg/filemanager/fs/dbfs/upload.go @@ -0,0 +1,364 @@ +package dbfs + +import ( + "context" + "fmt" + "math" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" +) + +func (f *DBFS) PrepareUpload(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (*fs.UploadSession, error) { + // Get navigator + navigator, err := f.getNavigator(ctx, req.Props.Uri, NavigatorCapabilityUploadFile, NavigatorCapabilityLockFile) + if err != nil { + return nil, err + } + + // Get most recent ancestor or target file + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + ancestor, err := f.getFileByPath(ctx, navigator, req.Props.Uri) + if err != nil && !ent.IsNotFound(err) { + return nil, fmt.Errorf("failed to get ancestor: %w", err) + } + + if ancestor.IsSymbolic() { + return nil, ErrSymbolicFolderFound + } + + fileExisted := false + if ancestor.Uri(false).IsSame(req.Props.Uri, hashid.EncodeUserID(f.hasher, f.user.ID)) { + fileExisted = true + } + + // If file already exist, and update operation is suspended or existing file is not a file + if fileExisted && (req.Props.EntityType == nil || ancestor.Type() != types.FileTypeFile) { + return nil, fs.ErrFileExisted + } + + // If file not exist, only empty entity / version entity is allowed + if !fileExisted && (req.Props.EntityType != nil && *req.Props.EntityType != types.EntityTypeVersion) { + return nil, fs.ErrPathNotExist + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && ancestor.OwnerID() != f.user.ID { + return nil, fs.ErrOwnerOnly + } + + // Lock target + lockedPath := ancestor.RootUri().JoinRaw(req.Props.Uri.PathTrimmed()) + lr := &LockByPath{lockedPath, ancestor, types.FileTypeFile, ""} + ls, err := 
f.acquireByPath(ctx, time.Until(req.Props.ExpireAt), f.user, false, fs.LockApp(fs.ApplicationUpload), lr) + defer func() { _ = f.Release(ctx, ls) }() + ctx = fs.LockSessionToContext(ctx, ls) + if err != nil { + return nil, err + } + + // Get parent folder storage policy and performs validation + policy, err := f.getPreferredPolicy(ctx, ancestor) + if err != nil { + return nil, err + } + + // validate upload request + if err := validateNewFile(req.Props.Uri.Name(), req.Props.Size, policy); err != nil { + return nil, err + } + + // Validate available capacity + if err := f.validateUserCapacity(ctx, req.Props.Size, ancestor.Owner()); err != nil { + return nil, err + } + + // Generate save path by storage policy + isThumbnailAndPolicyNotAvailable := policy.ID != ancestor.Model.StoragePolicyFiles && (req.Props.EntityType != nil && *req.Props.EntityType == types.EntityTypeThumbnail) + if req.Props.SavePath == "" || isThumbnailAndPolicyNotAvailable { + req.Props.SavePath = generateSavePath(policy, req, f.user) + if isThumbnailAndPolicyNotAvailable { + req.Props.SavePath = fmt.Sprintf( + "%s.%s%s", + req.Props.SavePath, + util.RandStringRunes(16), + f.settingClient.ThumbEntitySuffix(ctx)) + } + } + + // Create upload placeholder + var ( + fileId int + entityId int + lockToken string + targetFile *ent.File + ) + fc, dbTx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + if fileExisted { + entityType := types.EntityTypeVersion + if req.Props.EntityType != nil { + entityType = *req.Props.EntityType + } + entity, err := f.CreateEntity(ctx, ancestor, policy, entityType, req, + WithPreviousVersion(req.Props.PreviousVersion), + fs.WithUploadRequest(req), + WithRemoveStaleEntities(), + ) + if err != nil { + _ = inventory.Rollback(dbTx) + return nil, fmt.Errorf("failed to create new entity: %w", err) + } + fileId = ancestor.ID() + entityId = entity.ID() + targetFile = ancestor.Model + lockToken = ls.Exclude(lr, f.user, f.hasher) + } else { + uploadPlaceholder, err := f.Create(ctx, req.Props.Uri, types.FileTypeFile, + fs.WithUploadRequest(req), + WithPreferredStoragePolicy(policy), + WithErrorOnConflict(), + WithAncestor(ancestor), + ) + if err != nil { + _ = inventory.Rollback(dbTx) + return nil, fmt.Errorf("failed to create upload placeholder: %w", err) + } + + fileId = uploadPlaceholder.ID() + entityId = uploadPlaceholder.Entities()[0].ID() + targetFile = uploadPlaceholder.(*File).Model + lockToken = ls.Exclude(lr, f.user, f.hasher) + } + + // create metadata to record uploading entity id + if err := fc.UpsertMetadata(ctx, targetFile, map[string]string{ + MetadataUploadSessionID: req.Props.UploadSessionID, + }, nil); err != nil { + _ = inventory.Rollback(dbTx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update upload session metadata", err) + } + + if err := inventory.CommitWithStorageDiff(ctx, dbTx, f.l, f.userClient); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit file upload preparation", err) + } + + session := &fs.UploadSession{ + Props: &fs.UploadProps{ + Uri: req.Props.Uri, + Size: req.Props.Size, + SavePath: req.Props.SavePath, + LastModified: req.Props.LastModified, + UploadSessionID: req.Props.UploadSessionID, + ExpireAt: req.Props.ExpireAt, + EntityType: req.Props.EntityType, + }, + FileID: fileId, + NewFileCreated: !fileExisted, + EntityID: entityId, + UID: f.user.ID, + Policy: policy, + CallbackSecret: 
util.RandStringRunes(32), + LockToken: lockToken, // Prevent lock being released. + } + + // TODO: frontend should create new upload session if resumed session does not exist. + return session, nil +} + +func (f *DBFS) CompleteUpload(ctx context.Context, session *fs.UploadSession) (fs.File, error) { + // Get placeholder file + file, err := f.Get(ctx, session.Props.Uri, WithFileEntities()) + if err != nil { + return nil, fmt.Errorf("failed to get placeholder file: %w", err) + } + + filePrivate := file.(*File) + + // Confirm locks on placeholder file + if session.LockToken != "" { + release, ls, err := f.ConfirmLock(ctx, file, file.Uri(false), session.LockToken) + if err != nil { + return nil, fs.ErrLockExpired.WithError(err) + } + + release() + ctx = fs.LockSessionToContext(ctx, ls) + } + + // Update placeholder entity to actual desired entity + entityType := types.EntityTypeVersion + if session.Props.EntityType != nil { + entityType = *session.Props.EntityType + } + + // Check version retention policy + owner := filePrivate.Owner() + // Max allowed versions + maxVersions := 1 + if entityType == types.EntityTypeVersion && + owner.Settings.VersionRetention && + (len(owner.Settings.VersionRetentionExt) == 0 || util.IsInExtensionList(owner.Settings.VersionRetentionExt, file.Name())) { + // Retention is enabled for this file + maxVersions = owner.Settings.VersionRetentionMax + if maxVersions == 0 { + // Unlimited versions + maxVersions = math.MaxInt32 + } + } + + // Start transaction to update file + fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) + } + + err = fc.UpgradePlaceholder(ctx, filePrivate.Model, session.Props.LastModified, session.EntityID, entityType) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update placeholder file", err) + } + + // Remove metadata that are defined in upload session + err = fc.RemoveMetadata(ctx, filePrivate.Model, MetadataUploadSessionID, ThumbDisabledKey) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update placeholder metadata", err) + } + + if len(session.Props.Metadata) > 0 { + if err := fc.UpsertMetadata(ctx, filePrivate.Model, session.Props.Metadata, nil); err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to upsert placeholder metadata", err) + } + } + + diff, err := fc.CapEntities(ctx, filePrivate.Model, owner, maxVersions, entityType) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to cap version entities", err) + } + tx.AppendStorageDiff(diff) + + if entityType == types.EntityTypeVersion { + // If updating version entity, we need to cap all existing thumbnail entity to let it re-generate. 
+ diff, err = fc.CapEntities(ctx, filePrivate.Model, owner, 0, types.EntityTypeThumbnail) + if err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to cap thumbnail entities", err) + } + + tx.AppendStorageDiff(diff) + } + + if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit file change", err) + } + + // Unlock file + if session.LockToken != "" { + if err := f.ls.Unlock(time.Now(), session.LockToken); err != nil { + return nil, serializer.NewError(serializer.CodeLockConflict, "Failed to unlock file", err) + } + } + + file, err = f.Get(ctx, session.Props.Uri, WithFileEntities()) + if err != nil { + return nil, fmt.Errorf("failed to get updated file: %w", err) + } + + return file, nil +} + +// This function will be used: +// - File still locked by uplaod session +// - File unlocked, upload session valid +// - File unlocked, upload session not valid +func (f *DBFS) CancelUploadSession(ctx context.Context, path *fs.URI, sessionID string, session *fs.UploadSession) ([]fs.Entity, error) { + // Get placeholder file + file, err := f.Get(ctx, path, WithFileEntities()) + if err != nil { + return nil, fmt.Errorf("failed to get placeholder file: %w", err) + } + + filePrivate := file.(*File) + + // Make sure presented upload session is valid + if session != nil && (session.UID != f.user.ID || session.FileID != file.ID()) { + return nil, serializer.NewError(serializer.CodeNotFound, "Upload session not found", nil) + } + + // Confirm locks on placeholder file + if session != nil && session.LockToken != "" { + release, ls, err := f.ConfirmLock(ctx, file, file.Uri(false), session.LockToken) + if err == nil { + release() + ctx = fs.LockSessionToContext(ctx, ls) + } + } + + if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && filePrivate.OwnerID() != f.user.ID { + return nil, fs.ErrOwnerOnly + } + + // Lock file + ls, err := f.acquireByPath(ctx, -1, f.user, true, fs.LockApp(fs.ApplicationUpload), + &LockByPath{filePrivate.Uri(true), filePrivate, filePrivate.Type(), ""}) + defer func() { _ = f.Release(ctx, ls) }() + ctx = fs.LockSessionToContext(ctx, ls) + if err != nil { + return nil, err + } + + // Find placeholder entity + var entity fs.Entity + for _, e := range filePrivate.Entities() { + if sid := e.UploadSessionID(); sid != nil && sid.String() == sessionID { + entity = e + break + } + } + + // Remove upload session metadata + if err := f.fileClient.RemoveMetadata(ctx, filePrivate.Model, MetadataUploadSessionID, ThumbDisabledKey); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to remove upload session metadata", err) + } + + if entity == nil { + // Given upload session does not exist + return nil, nil + } + + if session != nil && session.LockToken != "" { + defer func() { + if err := f.ls.Unlock(time.Now(), session.LockToken); err != nil { + f.l.Warning("Failed to unlock file %q: %s", filePrivate.Uri(true).String(), err) + } + }() + } + + if len(filePrivate.Entities()) == 1 { + // Only one placeholder entity, just delete this file + return f.Delete(ctx, []*fs.URI{path}) + } + + // Delete place holder entity + storageDiff, err := f.deleteEntity(ctx, filePrivate, entity.ID()) + if err != nil { + return nil, fmt.Errorf("failed to delete placeholder entity: %w", err) + } + + if err := f.userClient.ApplyStorageDiff(ctx, storageDiff); err != nil { + return nil, fmt.Errorf("failed to apply storage diff: %w", 
err) + } + + return nil, nil +} diff --git a/pkg/filemanager/fs/dbfs/validator.go b/pkg/filemanager/fs/dbfs/validator.go new file mode 100644 index 00000000..71337499 --- /dev/null +++ b/pkg/filemanager/fs/dbfs/validator.go @@ -0,0 +1,88 @@ +package dbfs + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "strings" +) + +const MaxFileNameLength = 256 + +// validateFileName validates the file name. +func validateFileName(name string) error { + if len(name) >= MaxFileNameLength || len(name) == 0 { + return fmt.Errorf("length of name must be between 1 and 255") + } + + if strings.ContainsAny(name, "\\/:*?\"<>|") { + return fmt.Errorf("name contains illegal characters") + } + + if name == "." || name == ".." { + return fmt.Errorf("name cannot be only dot") + } + + return nil +} + +// validateExtension validates the file extension. +func validateExtension(name string, policy *ent.StoragePolicy) error { + // 不需要验证 + if len(policy.Settings.FileType) == 0 { + return nil + } + + if !util.IsInExtensionList(policy.Settings.FileType, name) { + return fmt.Errorf("file extension is not allowed") + } + + return nil +} + +// validateFileSize validates the file size. +func validateFileSize(size int64, policy *ent.StoragePolicy) error { + if policy.MaxSize == 0 { + return nil + } else if size > policy.MaxSize { + return fs.ErrFileSizeTooBig + } + + return nil +} + +// validateNewFile validates the upload request. +func validateNewFile(fileName string, size int64, policy *ent.StoragePolicy) error { + if err := validateFileName(fileName); err != nil { + return err + } + + if err := validateExtension(fileName, policy); err != nil { + return err + } + + if err := validateFileSize(size, policy); err != nil { + return err + } + + return nil +} + +func (f *DBFS) validateUserCapacity(ctx context.Context, size int64, u *ent.User) error { + capacity, err := f.Capacity(ctx, u) + if err != nil { + return fmt.Errorf("failed to get user capacity: %s", err) + } + + return f.validateUserCapacityRaw(ctx, size, capacity) +} + +// validateUserCapacityRaw validates the user capacity, but does not fetch the capacity. 
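For orientation, a minimal sketch of how the name and size validators above behave; the storage policy literal is illustrative and only MaxSize is populated:

func validationSketch() {
	policy := &ent.StoragePolicy{MaxSize: 10 << 20} // hypothetical 10 MiB cap

	_ = validateFileName("report.pdf")   // nil
	_ = validateFileName("bad:name.txt") // error: name contains illegal characters
	_ = validateFileSize(20<<20, policy) // fs.ErrFileSizeTooBig (20 MiB > 10 MiB)
}

validateNewFile simply chains these checks, plus the extension whitelist, and returns the first failure.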
+func (f *DBFS) validateUserCapacityRaw(ctx context.Context, size int64, capacity *fs.Capacity) error { + if capacity.Used+size > capacity.Total { + return fs.ErrInsufficientCapacity + } + return nil +} diff --git a/pkg/filemanager/fs/fs.go b/pkg/filemanager/fs/fs.go new file mode 100644 index 00000000..12f505a5 --- /dev/null +++ b/pkg/filemanager/fs/fs.go @@ -0,0 +1,763 @@ +package fs + +import ( + "context" + "encoding/gob" + "errors" + "fmt" + "io" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gofrs/uuid" +) + +type FsCapability int + +const ( + FsCapabilityList = FsCapability(iota) +) + +var ( + ErrDirectLinkInvalid = serializer.NewError(serializer.CodeNotFound, "Direct link invalid", nil) + ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil) + ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil) + ErrFileDeleted = serializer.NewError(serializer.CodeFileDeleted, "File deleted", nil) + ErrEntityNotExist = serializer.NewError(serializer.CodeEntityNotExist, "Entity not exist", nil) + ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil) + ErrNotSupportedAction = serializer.NewError(serializer.CodeNoPermissionErr, "Not supported action", nil) + ErrLockConflict = serializer.NewError(serializer.CodeLockConflict, "Lock conflict", nil) + ErrLockExpired = serializer.NewError(serializer.CodeLockConflict, "Lock expired", nil) + ErrModified = serializer.NewError(serializer.CodeConflict, "Object conflict", nil) + ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil) + ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil) + ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil) + ErrStaleVersion = serializer.NewError(serializer.CodeStaleVersion, "File is updated during your edit", nil) + ErrOwnerOnly = serializer.NewError(serializer.CodeOwnerOnly, "Only owner or administrator can perform this action", nil) + ErrArchiveSrcSizeTooBig = ErrFileSizeTooBig.WithError(fmt.Errorf("total size of to-be compressed file exceed group limit (%w)", queue.CriticalErr)) +) + +type ( + FileSystem interface { + LockSystem + UploadManager + FileManager + // Recycle recycles a DBFS and its generated resources. + Recycle() + // Capacity returns the storage capacity of the filesystem. + Capacity(ctx context.Context, u *ent.User) (*Capacity, error) + // CheckCapability checks if the filesystem supports given capability. + CheckCapability(ctx context.Context, uri *URI, opts ...Option) error + // StaleEntities returns all stale entities of given IDs. If no ID is given, all + // potential stale entities will be returned. + StaleEntities(ctx context.Context, entities ...int) ([]Entity, error) + // AllFilesInTrashBin returns all files in trash bin, despite owner. + AllFilesInTrashBin(ctx context.Context, opts ...Option) (*ListFileResult, error) + // Walk walks through all files under given path with given depth limit. 
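To illustrate the Walk contract documented above, a hypothetical traversal that counts regular files; fsys stands in for any FileSystem implementation, and the archive code later in this diff walks trees the same way with an effectively unlimited depth:

func countFiles(ctx context.Context, fsys fs.FileSystem, uri *fs.URI) (int, error) {
	count := 0
	err := fsys.Walk(ctx, uri, math.MaxInt, func(f fs.File, level int) error {
		if f.Type() == types.FileTypeFile && !f.IsSymbolic() {
			count++
		}
		return nil // returning an error aborts the walk
	})
	return count, err
}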
+ Walk(ctx context.Context, path *URI, depth int, walk WalkFunc, opts ...Option) error + // SharedAddressTranslation translates a path that potentially contain shared symbolic to a real address. + SharedAddressTranslation(ctx context.Context, path *URI, opts ...Option) (File, *URI, error) + // ExecuteNavigatorHooks executes hooks of given type on a file for navigator based custom hooks. + ExecuteNavigatorHooks(ctx context.Context, hookType HookType, file File) error + } + + FileManager interface { + // Get returns a file by its path. + Get(ctx context.Context, path *URI, opts ...Option) (File, error) + // Create creates a file. + Create(ctx context.Context, path *URI, fileType types.FileType, opts ...Option) (File, error) + // List lists files under give path. + List(ctx context.Context, path *URI, opts ...Option) (File, *ListFileResult, error) + // Rename renames a file. + Rename(ctx context.Context, path *URI, newName string) (File, error) + // Move moves files to dst. + MoveOrCopy(ctx context.Context, path []*URI, dst *URI, isCopy bool) error + // Delete performs hard-delete for given paths, return newly generated stale entities in this delete operation. + Delete(ctx context.Context, path []*URI, opts ...Option) ([]Entity, error) + // GetEntitiesFromFileID returns all entities of a given file. + GetEntity(ctx context.Context, entityID int) (Entity, error) + // UpsertMetadata update or insert metadata of a file. + PatchMetadata(ctx context.Context, path []*URI, metas ...MetadataPatch) error + // SoftDelete moves given files to trash bin. + SoftDelete(ctx context.Context, path ...*URI) error + // Restore restores given files from trash bin to its original location. + Restore(ctx context.Context, path ...*URI) error + // VersionControl performs version control on given file. + // - `delete` is false: set version as current version; + // - `delete` is true: delete version. + VersionControl(ctx context.Context, path *URI, versionId int, delete bool) error + } + + UploadManager interface { + // PrepareUpload prepares an upload session. It performs validation on upload request and returns a placeholder + // file if needed. + PrepareUpload(ctx context.Context, req *UploadRequest, opts ...Option) (*UploadSession, error) + // CompleteUpload completes an upload session. + CompleteUpload(ctx context.Context, session *UploadSession) (File, error) + // CancelUploadSession cancels an upload session. Delete the placeholder file if no other entity is created. + CancelUploadSession(ctx context.Context, path *URI, sessionID string, session *UploadSession) ([]Entity, error) + } + + LockSystem interface { + // ConfirmLock confirms if a lock token is valid on given URI. + ConfirmLock(ctx context.Context, ancestor File, uri *URI, token ...string) (func(), LockSession, error) + // Lock locks a file. If zeroDepth is true, only the file itself will be locked. Ancestor is closest ancestor + // of the file that will be locked, if the given uri is an existing file, ancestor will be itself. + // `token` is optional and can be used if the requester need to explicitly specify a token. + Lock(ctx context.Context, d time.Duration, requester *ent.User, zeroDepth bool, application lock.Application, + uri *URI, token string) (LockSession, error) + // Unlock unlocks files by given tokens. + Unlock(ctx context.Context, tokens ...string) error + // Refresh refreshes a lock. 
+ Refresh(ctx context.Context, d time.Duration, token string) (lock.LockDetails, error) + } + + StatelessUploadManager interface { + // PrepareUpload prepares the upload on the node. + PrepareUpload(ctx context.Context, args *StatelessPrepareUploadService) (*StatelessPrepareUploadResponse, error) + // CompleteUpload completes the upload on the node. + CompleteUpload(ctx context.Context, args *StatelessCompleteUploadService) error + // OnUploadFailed handles the failed upload on the node. + OnUploadFailed(ctx context.Context, args *StatelessOnUploadFailedService) error + // CreateFile creates a file on the node. + CreateFile(ctx context.Context, args *StatelessCreateFileService) error + } + + WalkFunc func(file File, level int) error + + File interface { + IsNil() bool + ID() int + Name() string + DisplayName() string + Ext() string + Type() types.FileType + Size() int64 + UpdatedAt() time.Time + CreatedAt() time.Time + Metadata() map[string]string + // Uri returns the URI of the file. + Uri(isRoot bool) *URI + Owner() *ent.User + OwnerID() int + // RootUri return the URI of the user root file under owner's view. + RootUri() *URI + Entities() []Entity + PrimaryEntity() Entity + PrimaryEntityID() int + Shared() bool + IsSymbolic() bool + PolicyID() (id int) + ExtendedInfo() *FileExtendedInfo + FolderSummary() *FolderSummary + Capabilities() *boolset.BooleanSet + } + + Entities []Entity + Entity interface { + ID() int + Type() types.EntityType + Size() int64 + UpdatedAt() time.Time + CreatedAt() time.Time + Source() string + ReferenceCount() int + PolicyID() int + UploadSessionID() *uuid.UUID + CreatedBy() *ent.User + Model() *ent.Entity + } + + FileExtendedInfo struct { + StoragePolicy *ent.StoragePolicy + StorageUsed int64 + Shares []*ent.Share + EntityStoragePolicies map[int]*ent.StoragePolicy + } + + FolderSummary struct { + Size int64 `json:"size"` + Files int `json:"files"` + Folders int `json:"folders"` + Completed bool `json:"completed"` // whether the size calculation is completed + CalculatedAt time.Time `json:"calculated_at"` + } + + MetadataPatch struct { + Key string `json:"key" binding:"required"` + Value string `json:"value"` + Private bool `json:"private" binding:"ne=true"` + Remove bool `json:"remove"` + } + + // ListFileResult result of listing files. + ListFileResult struct { + Files []File + Parent File + Pagination *inventory.PaginationResults + Props *NavigatorProps + ContextHint *uuid.UUID + RecursionLimitReached bool + MixedType bool + SingleFileView bool + StoragePolicy *ent.StoragePolicy + } + + // NavigatorProps is the properties of current filesystem. + NavigatorProps struct { + // Supported capabilities of the navigator. + Capability *boolset.BooleanSet `json:"capability"` + // MaxPageSize is the maximum page size of the navigator. + MaxPageSize int `json:"max_page_size"` + // OrderByOptions is the supported order by options of the navigator. + OrderByOptions []string `json:"order_by_options"` + // OrderDirectionOptions is the supported order direction options of the navigator. + OrderDirectionOptions []string `json:"order_direction_options"` + } + + // UploadCredential for uploading files in client side. 
+ UploadCredential struct { + SessionID string `json:"session_id"` + ChunkSize int64 `json:"chunk_size"` // 分块大小,0 为部分快 + Expires int64 `json:"expires"` // 上传凭证过期时间, Unix 时间戳 + UploadURLs []string `json:"upload_urls,omitempty"` + Credential string `json:"credential,omitempty"` + UploadID string `json:"uploadID,omitempty"` + Callback string `json:"callback,omitempty"` // 回调地址 + Uri string `json:"uri,omitempty"` // 存储路径 + AccessKey string `json:"ak,omitempty"` + KeyTime string `json:"keyTime,omitempty"` // COS用有效期 + CompleteURL string `json:"completeURL,omitempty"` + StoragePolicy *ent.StoragePolicy + CallbackSecret string `json:"callback_secret,omitempty"` + MimeType string `json:"mime_type,omitempty"` // Expected mimetype + UploadPolicy string `json:"upload_policy,omitempty"` // Upyun upload policy + } + + // UploadSession stores the information of an upload session, used in server side. + UploadSession struct { + UID int // 发起者 + Policy *ent.StoragePolicy + FileID int // ID of the placeholder file + EntityID int // ID of the new entity + Callback string // 回调 URL 地址 + CallbackSecret string // Callback secret + UploadID string // Multi-part upload ID + UploadURL string + Credential string + ChunkSize int64 + SentinelTaskID int + NewFileCreated bool // If new file is created for this session + + LockToken string // Token of the locked placeholder file + Props *UploadProps + } + + // UploadProps properties of an upload session/request. + UploadProps struct { + Uri *URI + Size int64 + UploadSessionID string + PreferredStoragePolicy int + SavePath string + LastModified *time.Time + MimeType string + Metadata map[string]string + PreviousVersion string + // EntityType is the type of the entity to be created. If not set, a new file will be created + // with a default version entity. This will be set in update request for existing files. + EntityType *types.EntityType + ExpireAt time.Time + } + + // FsOption options for underlying file system. + FsOption struct { + Page int // Page number when listing files. + PageSize int // Size of pages when listing files. + OrderBy string + OrderDirection string + UploadRequest *UploadRequest + UnlinkOnly bool + UploadSession *UploadSession + DownloadSpeed int64 + IsDownload bool + Expire *time.Time + Entity Entity + IsThumb bool + EntityType *types.EntityType + EntityTypeNil bool + SkipSoftDelete bool + SysSkipSoftDelete bool + Metadata map[string]string + ArchiveCompression bool + ProgressFunc + MaxArchiveSize int64 + DryRun CreateArchiveDryRunFunc + Policy *ent.StoragePolicy + Node StatelessUploadManager + StatelessUserID int + NoCache bool + } + + // Option 发送请求的额外设置 + Option interface { + Apply(any) + } + + OptionFunc func(*FsOption) + + // Ctx keys used to detect user canceled operation. + UserCancelCtx struct{} + GinCtx struct{} + + // Capacity describes the capacity of a filesystem. 
+ Capacity struct { + Total int64 `json:"total"` + Used int64 `json:"used"` + } + + FileCapacity int + + LockSession interface { + LastToken() string + } + + HookType int + + CreateArchiveDryRunFunc func(name string, e Entity) + + StatelessPrepareUploadService struct { + UploadRequest *UploadRequest `json:"upload_request" binding:"required"` + UserID int `json:"user_id"` + } + StatelessCompleteUploadService struct { + UploadSession *UploadSession `json:"upload_session" binding:"required"` + UserID int `json:"user_id"` + } + StatelessOnUploadFailedService struct { + UploadSession *UploadSession `json:"upload_session" binding:"required"` + UserID int `json:"user_id"` + } + StatelessCreateFileService struct { + Path string `json:"path" binding:"required"` + Type types.FileType `json:"type" binding:"required"` + UserID int `json:"user_id"` + } + StatelessPrepareUploadResponse struct { + Session *UploadSession + Req *UploadRequest + } + + PrepareRelocateRes struct { + Entities map[int]*RelocateEntity `json:"entities,omitempty"` + LockToken string `json:"lock_token,omitempty"` + Policy *ent.StoragePolicy `json:"policy,omitempty"` + } + + RelocateEntity struct { + SrcEntity *ent.Entity `json:"src_entity"` + FileUri *URI `json:"file_uri,omitempty"` + NewSavePath string `json:"new_save_path"` + ParentFiles []int `json:"parent_files"` + PrimaryEntityParentFiles []int `json:"primary_entity_parent_files"` + } +) + +const ( + FileCapacityPreview FileCapacity = iota + FileCapacityEnter + FileCapacityDownload + FileCapacityRename + FileCapacityCopy + FileCapacityMove +) + +const ( + HookTypeBeforeDownload = HookType(iota) +) + +func (p *UploadProps) Copy() *UploadProps { + newProps := *p + return &newProps +} + +func (f OptionFunc) Apply(o any) { + f(o.(*FsOption)) +} + +// ==================== FS Options ==================== + +// WithUploadSession sets upload session for manager. +func WithUploadSession(s *UploadSession) Option { + return OptionFunc(func(o *FsOption) { + o.UploadSession = s + }) +} + +// WithPageSize limit items in a page for listing files. +func WithPageSize(s int) Option { + return OptionFunc(func(o *FsOption) { + o.PageSize = s + }) +} + +// WithPage set page number for listing files. +func WithPage(p int) Option { + return OptionFunc(func(o *FsOption) { + o.Page = p + }) +} + +// WithOrderBy set order by for listing files. +func WithOrderBy(p string) Option { + return OptionFunc(func(o *FsOption) { + o.OrderBy = p + }) +} + +// WithOrderDirection set order direction for listing files. +func WithOrderDirection(p string) Option { + return OptionFunc(func(o *FsOption) { + o.OrderDirection = p + }) +} + +// WithUploadRequest set upload request for uploading files. +func WithUploadRequest(p *UploadRequest) Option { + return OptionFunc(func(o *FsOption) { + o.UploadRequest = p + }) +} + +// WithProgressFunc set progress function for manager. +func WithProgressFunc(p ProgressFunc) Option { + return OptionFunc(func(o *FsOption) { + o.ProgressFunc = p + }) +} + +// WithUnlinkOnly set unlink only for unlinking files. +func WithUnlinkOnly(p bool) Option { + return OptionFunc(func(o *FsOption) { + o.UnlinkOnly = p + }) +} + +// WithDownloadSpeed sets download speed limit for manager. 
+func WithDownloadSpeed(speed int64) Option { + return OptionFunc(func(o *FsOption) { + o.DownloadSpeed = speed + }) +} + +func WithIsDownload(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.IsDownload = b + }) +} + +// WithSysSkipSoftDelete sets whether to skip soft delete without checking +// file ownership. +func WithSysSkipSoftDelete(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.SysSkipSoftDelete = b + }) +} + +// WithNoCache sets whether to disable cache for entity's URL. +func WithNoCache(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.NoCache = b + }) +} + +// WithUrlExpire sets expire time for entity's URL. +func WithUrlExpire(t *time.Time) Option { + return OptionFunc(func(o *FsOption) { + o.Expire = t + }) +} + +// WithEntity sets entity for manager. +func WithEntity(e Entity) Option { + return OptionFunc(func(o *FsOption) { + o.Entity = e + }) +} + +// WithPolicy sets storage policy overwrite for manager. +func WithPolicy(p *ent.StoragePolicy) Option { + return OptionFunc(func(o *FsOption) { + o.Policy = p + }) +} + +// WithUseThumb sets whether entity's URL is used for thumbnail. +func WithUseThumb(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.IsThumb = b + }) +} + +// WithEntityType sets entity type for manager. +func WithEntityType(t types.EntityType) Option { + return OptionFunc(func(o *FsOption) { + o.EntityType = &t + }) +} + +// WithNoEntityType sets entity type to nil for manager. +func WithNoEntityType() Option { + return OptionFunc(func(o *FsOption) { + o.EntityTypeNil = true + }) +} + +// WithSkipSoftDelete sets whether to skip soft delete. +func WithSkipSoftDelete(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.SkipSoftDelete = b + }) +} + +// WithMetadata sets metadata for file creation. +func WithMetadata(m map[string]string) Option { + return OptionFunc(func(o *FsOption) { + o.Metadata = m + }) +} + +// WithArchiveCompression sets whether to compress files in archive. +func WithArchiveCompression(b bool) Option { + return OptionFunc(func(o *FsOption) { + o.ArchiveCompression = b + }) +} + +// WithMaxArchiveSize sets maximum size of to be archived file or to-be decompressed +// size, 0 for unlimited. +func WithMaxArchiveSize(s int64) Option { + return OptionFunc(func(o *FsOption) { + o.MaxArchiveSize = s + }) +} + +// WithDryRun sets whether to perform dry run. +func WithDryRun(b CreateArchiveDryRunFunc) Option { + return OptionFunc(func(o *FsOption) { + o.DryRun = b + }) +} + +// WithNode sets node for stateless upload manager. +func WithNode(n StatelessUploadManager) Option { + return OptionFunc(func(o *FsOption) { + o.Node = n + }) +} + +// WithStatelessUserID sets stateless user ID for manager. 
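The With* helpers in this file are a standard functional-options pattern. A sketch of how a caller outside the package might combine them and how an implementation folds them into FsOption (the call site and option values are illustrative):

func listOptionsSketch() *fs.FsOption {
	opts := []fs.Option{
		fs.WithPage(0),
		fs.WithPageSize(50),
		fs.WithOrderBy("updated_at"),
		fs.WithOrderDirection("desc"),
	}

	// Implementations fold the options into an FsOption value before use.
	o := &fs.FsOption{}
	for _, opt := range opts {
		opt.Apply(o)
	}
	return o
}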
+func WithStatelessUserID(id int) Option { + return OptionFunc(func(o *FsOption) { + o.StatelessUserID = id + }) +} + +type WriteMode int + +const ( + ModeNone WriteMode = 0x00000 + ModeOverwrite WriteMode = 0x00001 + // Deprecated + ModeNop WriteMode = 0x00004 +) + +type ( + ProgressFunc func(current, diff int64, total int64) + UploadRequest struct { + Props *UploadProps + + Mode WriteMode + File io.ReadCloser `json:"-"` + Seeker io.Seeker `json:"-"` + Offset int64 + ProgressFunc `json:"-"` + + read int64 + } +) + +func (file *UploadRequest) Read(p []byte) (n int, err error) { + if file.File != nil { + n, err = file.File.Read(p) + file.read += int64(n) + if file.ProgressFunc != nil { + file.ProgressFunc(file.read, int64(n), file.Props.Size) + } + + return + } + + return 0, io.EOF +} + +func (file *UploadRequest) Close() error { + if file.File != nil { + return file.File.Close() + } + + return nil +} + +func (file *UploadRequest) Seek(offset int64, whence int) (int64, error) { + if file.Seekable() { + previous := file.read + o, err := file.Seeker.Seek(offset, whence) + file.read = o + if file.ProgressFunc != nil { + file.ProgressFunc(o, file.read-previous, file.Props.Size) + } + return o, err + } + + return 0, errors.New("no seeker") +} + +func (file *UploadRequest) Seekable() bool { + return file.Seeker != nil +} + +func init() { + gob.Register(UploadSession{}) + gob.Register(FolderSummary{}) +} + +type ApplicationType string + +const ( + ApplicationCreate ApplicationType = "create" + ApplicationRename ApplicationType = "rename" + ApplicationSetPermission ApplicationType = "setPermission" + ApplicationMoveCopy ApplicationType = "moveCopy" + ApplicationUpload ApplicationType = "upload" + ApplicationUpdateMetadata ApplicationType = "updateMetadata" + ApplicationDelete ApplicationType = "delete" + ApplicationSoftDelete ApplicationType = "softDelete" + ApplicationDAV ApplicationType = "dav" + ApplicationVersionControl ApplicationType = "versionControl" + ApplicationViewer ApplicationType = "viewer" + ApplicationMount ApplicationType = "mount" + ApplicationRelocate ApplicationType = "relocate" +) + +func LockApp(a ApplicationType) lock.Application { + return lock.Application{Type: string(a)} +} + +type LockSessionCtxKey struct{} + +// LockSessionToContext stores lock session to context. 
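A brief sketch of the lock-session context round-trip; the storing side matches the calls in upload.go above, while the retrieval side is an assumption based on the exported LockSessionCtxKey:

// Store the session on the context (ctx and ls assumed to exist at the call site) ...
ctx = fs.LockSessionToContext(ctx, ls)

// ... and, hypothetically, recover it downstream:
if session, ok := ctx.Value(fs.LockSessionCtxKey{}).(fs.LockSession); ok {
	_ = session.LastToken()
}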
+func LockSessionToContext(ctx context.Context, session LockSession) context.Context { + return context.WithValue(ctx, LockSessionCtxKey{}, session) +} + +func FindDesiredEntity(file File, version string, hasher hashid.Encoder, entityType *types.EntityType) (bool, Entity) { + if version == "" { + return true, file.PrimaryEntity() + } + + requestedVersion, err := hasher.Decode(version, hashid.EntityID) + if err != nil { + return false, nil + } + + hasVersions := false + for _, entity := range file.Entities() { + if entity.Type() == types.EntityTypeVersion { + hasVersions = true + } + + if entity.ID() == requestedVersion && (entityType == nil || *entityType == entity.Type()) { + return true, entity + } + } + + // Happy path for: File has no versions, requested version is empty entity + if !hasVersions && requestedVersion == 0 { + return true, file.PrimaryEntity() + } + + return false, nil +} + +type DbEntity struct { + model *ent.Entity +} + +func NewEntity(model *ent.Entity) Entity { + return &DbEntity{model: model} +} + +func (e *DbEntity) ID() int { + return e.model.ID +} + +func (e *DbEntity) Type() types.EntityType { + return types.EntityType(e.model.Type) +} + +func (e *DbEntity) Size() int64 { + return e.model.Size +} + +func (e *DbEntity) UpdatedAt() time.Time { + return e.model.UpdatedAt +} + +func (e *DbEntity) CreatedAt() time.Time { + return e.model.CreatedAt +} + +func (e *DbEntity) CreatedBy() *ent.User { + return e.model.Edges.User +} + +func (e *DbEntity) Source() string { + return e.model.Source +} + +func (e *DbEntity) ReferenceCount() int { + return e.model.ReferenceCount +} + +func (e *DbEntity) PolicyID() int { + return e.model.StoragePolicyEntities +} + +func (e *DbEntity) UploadSessionID() *uuid.UUID { + return e.model.UploadSessionID +} + +func (e *DbEntity) Model() *ent.Entity { + return e.model +} + +func NewEmptyEntity(u *ent.User) Entity { + return &DbEntity{ + model: &ent.Entity{ + UpdatedAt: time.Now(), + ReferenceCount: 1, + CreatedAt: time.Now(), + Edges: ent.EntityEdges{ + User: u, + }, + }, + } +} diff --git a/pkg/filemanager/fs/mime/mime.go b/pkg/filemanager/fs/mime/mime.go new file mode 100644 index 00000000..751fd5c5 --- /dev/null +++ b/pkg/filemanager/fs/mime/mime.go @@ -0,0 +1,40 @@ +package mime + +import ( + "context" + "encoding/json" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "mime" + "path" +) + +type MimeDetector interface { + // TypeByName returns the mime type by file name. 
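A short usage sketch for the detector defined below; the settings and logger values are assumed to come from the usual dependency wiring, and the import alias is hypothetical, used only to avoid clashing with the standard library "mime" package:

detector := cdmime.NewMimeDetector(ctx, settings, logger)

detector.TypeByName("photo.jpg")    // "image/jpeg", via the stdlib mime.TypeByExtension fallback
detector.TypeByName("notes.custom") // "", unless ".custom" is present in the configured mapping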
+ TypeByName(ext string) string +} + +type mimeDetector struct { + mapping map[string]string +} + +func NewMimeDetector(ctx context.Context, settings setting.Provider, l logging.Logger) MimeDetector { + mappingStr := settings.MimeMapping(ctx) + mapping := make(map[string]string) + if err := json.Unmarshal([]byte(mappingStr), &mapping); err != nil { + l.Error("Failed to unmarshal mime mapping: %s, fallback to empty mapping", err) + } + + return &mimeDetector{ + mapping: mapping, + } +} + +func (d *mimeDetector) TypeByName(p string) string { + ext := path.Ext(p) + if m, ok := d.mapping[ext]; ok { + return m + } + + return mime.TypeByExtension(ext) +} diff --git a/pkg/filemanager/fs/uri.go b/pkg/filemanager/fs/uri.go new file mode 100644 index 00000000..0db4460d --- /dev/null +++ b/pkg/filemanager/fs/uri.go @@ -0,0 +1,421 @@ +package fs + +import ( + "encoding/json" + "fmt" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" +) + +const ( + Separator = "/" +) + +const ( + QuerySearchName = "name" + QuerySearchNameOpOr = "use_or" + QuerySearchMetadataPrefix = "meta_" + QuerySearchCaseFolding = "case_folding" + QuerySearchType = "type" + QuerySearchTypeCategory = "category" + QuerySearchSizeGte = "size_gte" + QuerySearchSizeLte = "size_lte" + QuerySearchCreatedGte = "created_gte" + QuerySearchCreatedLte = "created_lte" + QuerySearchUpdatedGte = "updated_gte" + QuerySearchUpdatedLte = "updated_lte" +) + +type URI struct { + U *url.URL +} + +func NewUriFromString(u string) (*URI, error) { + raw, err := url.Parse(u) + if err != nil { + return nil, fmt.Errorf("failed to parse uri: %w", err) + } + + if raw.Scheme != constants.CloudreveScheme { + return nil, fmt.Errorf("unknown scheme: %s", raw.Scheme) + } + + if strings.HasSuffix(raw.Path, Separator) { + raw.Path = strings.TrimSuffix(raw.Path, Separator) + } + + return &URI{U: raw}, nil +} + +func NewUriFromStrings(u ...string) ([]*URI, error) { + res := make([]*URI, 0, len(u)) + for _, uri := range u { + fsUri, err := NewUriFromString(uri) + if err != nil { + return nil, err + } + + res = append(res, fsUri) + } + + return res, nil +} + +func (u *URI) UnmarshalBinary(text []byte) error { + raw, err := url.Parse(string(text)) + if err != nil { + return fmt.Errorf("failed to parse uri: %w", err) + } + + u.U = raw + return nil +} + +func (u *URI) MarshalBinary() ([]byte, error) { + return u.U.MarshalBinary() +} + +func (u *URI) MarshalJSON() ([]byte, error) { + r := map[string]string{ + "uri": u.String(), + } + return json.Marshal(r) +} + +func (u *URI) UnmarshalJSON(text []byte) error { + r := make(map[string]string) + err := json.Unmarshal(text, &r) + if err != nil { + return err + } + + u.U, err = url.Parse(r["uri"]) + if err != nil { + return err + } + + return nil +} + +func (u *URI) String() string { + return u.U.String() +} + +func (u *URI) Name() string { + return path.Base(u.Path()) +} + +func (u *URI) Dir() string { + return path.Dir(u.Path()) +} + +func (u *URI) Elements() []string { + res := strings.Split(u.PathTrimmed(), Separator) + if len(res) == 1 && res[0] == "" { + return nil + } + + return res +} + +func (u *URI) ID(defaultUid string) string { + if u.U.User == nil { + if u.FileSystem() != constants.FileSystemShare { + return defaultUid + } + return "" + } + + return u.U.User.Username() +} + +func 
(u *URI) Path() string { + p := u.U.Path + if !strings.HasPrefix(u.U.Path, Separator) { + p = Separator + u.U.Path + } + + return path.Clean(p) +} + +func (u *URI) PathTrimmed() string { + return strings.TrimPrefix(u.Path(), Separator) +} + +func (u *URI) Password() string { + if u.U.User == nil { + return "" + } + + pwd, _ := u.U.User.Password() + return pwd +} + +func (u *URI) Join(elem ...string) *URI { + newUrl, _ := url.Parse(u.U.String()) + return &URI{U: newUrl.JoinPath(lo.Map(elem, func(s string, i int) string { + return PathEscape(s) + })...)} +} + +// Join path with raw string +func (u *URI) JoinRaw(elem string) *URI { + return u.Join(strings.Split(strings.TrimPrefix(elem, Separator), Separator)...) +} + +func (u *URI) DirUri() *URI { + newUrl, _ := url.Parse(u.U.String()) + newUrl.Path = path.Dir(newUrl.Path) + + return &URI{U: newUrl} +} + +func (u *URI) Root() *URI { + newUrl, _ := url.Parse(u.U.String()) + newUrl.Path = Separator + newUrl.RawQuery = "" + + return &URI{U: newUrl} +} + +func (u *URI) SetQuery(q string) *URI { + newUrl, _ := url.Parse(u.U.String()) + newUrl.RawQuery = q + return &URI{U: newUrl} +} + +func (u *URI) IsSame(p *URI, uid string) bool { + return p.FileSystem() == u.FileSystem() && p.ID(uid) == u.ID(uid) && u.Path() == p.Path() +} + +// Rebased returns a new URI with the path rebased to the given base URI. It is +// commnly used in WebDAV address translation with shared folder symlink. +func (u *URI) Rebase(target, base *URI) *URI { + targetPath := target.Path() + basePath := base.Path() + rebasedPath := strings.TrimPrefix(targetPath, basePath) + + newUrl, _ := url.Parse(u.U.String()) + newUrl.Path = path.Join(newUrl.Path, rebasedPath) + return &URI{U: newUrl} +} + +func (u *URI) FileSystem() constants.FileSystemType { + return constants.FileSystemType(strings.ToLower(u.U.Host)) +} + +// SearchParameters returns the search parameters from the URI. If no search parameters are present, nil is returned. 
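For orientation, a small sketch of the URI accessors above; the URI literal is illustrative and assumes the scheme constant resolves to "cloudreve", with the host segment selecting the target file system:

u, err := fs.NewUriFromString("cloudreve://my/photos/2024")
if err != nil {
	// rejected: malformed URL or a non-cloudreve scheme
}

u.Name()     // "2024"
u.Dir()      // "/photos"
u.Elements() // []string{"photos", "2024"}

child := u.Join("trip.jpg")       // .../photos/2024/trip.jpg, segments escaped via PathEscape
search := u.SetQuery("name=trip") // SearchParameters() on this URI is non-nil, with Name = ["trip"]
_, _ = child, search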
+func (u *URI) SearchParameters() *inventory.SearchFileParameters { + q := u.U.Query() + res := &inventory.SearchFileParameters{ + Metadata: make(map[string]string), + } + withSearch := false + + if names, ok := q[QuerySearchName]; ok { + withSearch = len(names) > 0 + res.Name = names + } + + if _, ok := q[QuerySearchNameOpOr]; ok { + res.NameOperatorOr = true + } + + if _, ok := q[QuerySearchCaseFolding]; ok { + res.CaseFolding = true + } + + if v, ok := q[QuerySearchTypeCategory]; ok { + res.Category = v[0] + withSearch = withSearch || len(res.Category) > 0 + } + + if t, ok := q[QuerySearchType]; ok { + fileType := types.FileTypeFromString(t[0]) + res.Type = &fileType + withSearch = true + } + + for k, v := range q { + if strings.HasPrefix(k, QuerySearchMetadataPrefix) { + res.Metadata[strings.TrimPrefix(k, QuerySearchMetadataPrefix)] = v[0] + withSearch = true + } + } + + if v, ok := q[QuerySearchSizeGte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + res.SizeGte = limit + withSearch = true + } + } + + if v, ok := q[QuerySearchSizeLte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + res.SizeLte = limit + withSearch = true + } + } + + if v, ok := q[QuerySearchCreatedGte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + limit := time.Unix(limit, 0) + res.CreatedAtGte = &limit + withSearch = true + } + } + + if v, ok := q[QuerySearchCreatedLte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + limit := time.Unix(limit, 0) + res.CreatedAtLte = &limit + withSearch = true + } + } + + if v, ok := q[QuerySearchUpdatedGte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + limit := time.Unix(limit, 0) + res.UpdatedAtGte = &limit + withSearch = true + } + } + + if v, ok := q[QuerySearchUpdatedLte]; ok { + limit, err := strconv.ParseInt(v[0], 10, 64) + if err == nil { + limit := time.Unix(limit, 0) + res.UpdatedAtLte = &limit + withSearch = true + } + } + + if withSearch { + return res + } + + return nil +} + +// EqualOrIsDescendantOf returns true if the URI is equal to the given URI or if it is a descendant of the given URI. +func (u *URI) EqualOrIsDescendantOf(p *URI, uid string) bool { + prefix := p.Path() + if prefix[len(prefix)-1] != Separator[0] { + prefix += Separator + } + + return p.FileSystem() == u.FileSystem() && p.ID(uid) == u.ID(uid) && + (strings.HasPrefix(u.Path(), prefix) || u.Path() == p.Path()) +} + +func SearchCategoryFromString(s string) setting.SearchCategory { + switch s { + case "image": + return setting.CategoryImage + case "video": + return setting.CategoryVideo + case "audio": + return setting.CategoryAudio + case "document": + return setting.CategoryDocument + default: + return setting.CategoryUnknown + } +} + +func NewShareUri(id, password string) string { + if password != "" { + return fmt.Sprintf("%s://%s:%s@%s", constants.CloudreveScheme, id, password, constants.FileSystemShare) + } + return fmt.Sprintf("%s://%s@%s", constants.CloudreveScheme, id, constants.FileSystemShare) +} + +// PathEscape is same as url.PathEscape, with modifications to incoporate with JS encodeURI: +// encodeURI() escapes all characters except: +// +// A–Z a–z 0–9 - _ . ! ~ * ' ( ) +// ; / ? 
: @ & = + $ , # +func PathEscape(s string) string { + hexCount := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c) { + hexCount++ + } + } + + if hexCount == 0 { + return s + } + + var buf [64]byte + var t []byte + + required := len(s) + 2*hexCount + if required <= len(buf) { + t = buf[:required] + } else { + t = make([]byte, required) + } + + if hexCount == 0 { + copy(t, s) + for i := 0; i < len(s); i++ { + if s[i] == ' ' { + t[i] = '+' + } + } + return string(t) + } + + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case shouldEscape(c): + t[j] = '%' + t[j+1] = upperhex[c>>4] + t[j+2] = upperhex[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} + +const upperhex = "0123456789ABCDEF" + +// Return true if the specified character should be escaped when +// appearing in a URL string, according to RFC 3986. +// +// Please be informed that for now shouldEscape does not check all +// reserved characters correctly. See golang.org/issue/5684. +func shouldEscape(c byte) bool { + // §2.3 Unreserved characters (alphanum) + if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { + return false + } + + switch c { + case '-', '_', '.', '~', '!', '*', '\'', '(', ')', ';', '/', '?', ':', '@', '&', '=', '+', '$', ',', '#': // §2.3 Unreserved characters (mark) + return false + } + + // Everything else must be escaped. + return true +} diff --git a/pkg/filemanager/lock/memlock.go b/pkg/filemanager/lock/memlock.go new file mode 100644 index 00000000..219a47a2 --- /dev/null +++ b/pkg/filemanager/lock/memlock.go @@ -0,0 +1,467 @@ +package lock + +import ( + "container/heap" + "errors" + "strings" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +var ( + // ErrConfirmationFailed is returned by a LockSystem's Confirm method. + ErrConfirmationFailed = errors.New("memlock: confirmation failed") + ErrNoSuchLock = errors.New("memlock: no such lock") + ErrLocked = errors.New("memlock: locked") +) + +// LockSystem manages access to a collection of named resources. The elements +// in a lock name are separated by slash ('/', U+002F) characters, regardless +// of host operating system convention. +type LockSystem interface { + Create(now time.Time, details ...LockDetails) ([]string, error) + Unlock(now time.Time, tokens ...string) error + Confirm(now time.Time, requests LockInfo) (func(), string, error) + Refresh(now time.Time, duration time.Duration, token string) (LockDetails, error) +} + +// LockDetails are a lock's metadata. +type LockDetails struct { + // Root is the root resource name being locked. For a zero-depth lock, the + // root is the only resource being locked. + Root string + // Namespace of this lock. + Ns string + // Duration is the lock timeout. A negative duration means infinite. + Duration time.Duration + // Owner of this lock + Owner Owner + // ZeroDepth is whether the lock has zero depth. If it does not have zero + // depth, it has infinite depth. + ZeroDepth bool + // FileType is the type of the file being locked. This is used to display user-friendly error message. + Type types.FileType + // Optional, customize the token of the lock. 
+ Token string +} + +func (d *LockDetails) Key() string { + return d.Ns + "/" + d.Root +} + +type Owner struct { + // Name of the application who are currently lock this. + Application Application `json:"application"` +} + +type Application struct { + Type string `json:"type"` + InnerXML string `json:"inner_xml,omitempty"` + ViewerID string `json:"viewer_id,omitempty"` +} + +// LockInfo is a lock confirmation request. +type LockInfo struct { + Ns string + Root string + Token []string +} + +type memLS struct { + l logging.Logger + hasher hashid.Encoder + mu sync.Mutex + byName map[string]map[string]*memLSNode + byToken map[string]*memLSNode + gen uint64 + // byExpiry only contains those nodes whose LockDetails have a finite + // Duration and are yet to expire. + byExpiry byExpiry +} + +// NewMemLS returns a new in-memory LockSystem. +func NewMemLS(hasher hashid.Encoder, l logging.Logger) LockSystem { + return &memLS{ + byName: make(map[string]map[string]*memLSNode), + byToken: make(map[string]*memLSNode), + hasher: hasher, + l: l, + } +} + +func (m *memLS) Confirm(now time.Time, request LockInfo) (func(), string, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + m.l.Debug("Memlock confirm: NS:%s, Root: %s, Token: %v", request.Ns, request.Root, request.Token) + n := m.lookup(request.Ns, request.Root, request.Token...) + if n == nil { + return nil, "", ErrConfirmationFailed + } + + m.hold(n) + return func() { + m.mu.Lock() + defer m.mu.Unlock() + m.unhold(n) + }, n.token, nil +} + +func (m *memLS) Refresh(now time.Time, duration time.Duration, token string) (LockDetails, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + m.l.Debug("Memlock refresh: Token: %s, Duration: %v", token, duration) + n := m.byToken[token] + if n == nil { + return LockDetails{}, ErrNoSuchLock + } + if n.held { + return LockDetails{}, ErrLocked + } + if n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } + n.details.Duration = duration + if n.details.Duration >= 0 { + n.expiry = now.Add(n.details.Duration) + heap.Push(&m.byExpiry, n) + } + return n.details, nil +} + +func (m *memLS) Create(now time.Time, details ...LockDetails) ([]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + conflicts := make([]*ConflictDetail, 0) + locks := make([]*memLSNode, 0, len(details)) + for i, detail := range details { + // TODO: remove in production + // if !strings.Contains(detail.Ns, "my") && !strings.Contains(detail.Ns, "trash") { + // panic("invalid namespace") + // } + // Check lock conflicts + detail.Root = util.SlashClean(detail.Root) + m.l.Debug("Memlock create: NS:%s, Root: %s, Duration: %v, ZeroDepth: %v", detail.Ns, detail.Root, detail.Duration, detail.ZeroDepth) + conflict := m.canCreate(i, detail.Ns, detail.Root, detail.ZeroDepth) + if len(conflict) > 0 { + conflicts = append(conflicts, conflict...) 
+ // Stop processing more locks since there's already conflicts + break + } else { + // Create locks + n := m.create(detail.Ns, detail.Root, detail.Token) + m.byToken[n.token] = n + n.details = detail + if n.details.Duration >= 0 { + n.expiry = now.Add(n.details.Duration) + heap.Push(&m.byExpiry, n) + } + locks = append(locks, n) + } + } + + if len(conflicts) > 0 { + for _, l := range locks { + m.remove(l) + } + + return nil, ConflictError(conflicts) + } + + return lo.Map(locks, func(item *memLSNode, index int) string { + return item.token + }), nil +} + +func (m *memLS) canCreate(index int, ns, name string, zeroDepth bool) []*ConflictDetail { + n := m.byName[ns] + if n == nil { + return nil + } + + conflicts := make([]*ConflictDetail, 0) + canCreate := walkToRoot(name, func(name0 string, first bool) bool { + n := m.byName[ns][name0] + if n == nil { + return true + } + + if first { + if n.token != "" { + // The target node is already locked. + conflicts = append(conflicts, n.toConflictDetail(index, m.hasher)) + return false + } + if !zeroDepth { + // The requested lock depth is infinite, and the fact that n exists + // (n != nil) means that a descendent of the target node is locked. + conflicts = append(conflicts, + lo.MapToSlice(n.childLocks, func(key string, value *memLSNode) *ConflictDetail { + return value.toConflictDetail(index, m.hasher) + }, + )...) + return false + } + } else if n.token != "" && !n.details.ZeroDepth { + // An ancestor of the target node is locked with infinite depth. + conflicts = append(conflicts, n.toConflictDetail(index, m.hasher)) + return false + } + return true + }) + + if !canCreate { + return conflicts + } + + return nil +} + +func (m *memLS) Unlock(now time.Time, tokens ...string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + conflicts := make([]*ConflictDetail, 0) + toBeRemoved := make([]*memLSNode, 0, len(tokens)) + + for i, token := range tokens { + n := m.byToken[token] + if n == nil { + return ErrNoSuchLock + } + if n.held { + conflicts = append(conflicts, n.toConflictDetail(i, m.hasher)) + } else { + toBeRemoved = append(toBeRemoved, n) + } + } + + if len(conflicts) > 0 { + return ConflictError(conflicts) + } + + for _, n := range toBeRemoved { + m.remove(n) + } + + return nil +} + +func (m *memLS) collectExpiredNodes(now time.Time) { + for len(m.byExpiry) > 0 { + if now.Before(m.byExpiry[0].expiry) { + break + } + m.remove(m.byExpiry[0]) + } +} + +func (m *memLS) create(ns, name, token string) (ret *memLSNode) { + if _, ok := m.byName[ns]; !ok { + m.byName[ns] = make(map[string]*memLSNode) + } + + if token == "" { + token = uuid.Must(uuid.NewV4()).String() + } + + walkToRoot(name, func(name0 string, first bool) bool { + n := m.byName[ns][name0] + if n == nil { + n = &memLSNode{ + details: LockDetails{ + Root: name0, + }, + childLocks: make(map[string]*memLSNode), + byExpiryIndex: -1, + } + m.byName[ns][name0] = n + } + n.refCount++ + if first { + n.token = token + ret = n + } else { + n.childLocks[token] = ret + } + return true + }) + return ret +} + +func (m *memLS) lookup(ns, name string, tokens ...string) (n *memLSNode) { + for _, token := range tokens { + n = m.byToken[token] + if n == nil || n.held { + continue + } + if n.details.Ns != ns { + continue + } + if name == n.details.Root { + return n + } + if n.details.ZeroDepth { + continue + } + if n.details.Root == "/" || strings.HasPrefix(name, n.details.Root+"/") { + return n + } + } + return nil +} + +func (m *memLS) remove(n *memLSNode) { + delete(m.byToken, 
n.token) + token := n.token + n.token = "" + walkToRoot(n.details.Root, func(name0 string, first bool) bool { + x := m.byName[n.details.Ns][name0] + x.refCount-- + delete(x.childLocks, token) + if x.refCount == 0 { + delete(m.byName[n.details.Ns], name0) + if len(m.byName[n.details.Ns]) == 0 { + delete(m.byName, n.details.Root) + } + } + return true + }) + if n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } +} + +func (m *memLS) hold(n *memLSNode) { + if n.held { + panic("dbfs: memLS inconsistent held state") + } + n.held = true + if n.details.Duration >= 0 && n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } +} + +func (m *memLS) unhold(n *memLSNode) { + if !n.held { + panic("dbfs: memLS inconsistent held state") + } + n.held = false + if n.details.Duration >= 0 { + heap.Push(&m.byExpiry, n) + } +} + +func walkToRoot(name string, f func(name0 string, first bool) bool) bool { + for first := true; ; first = false { + if !f(name, first) { + return false + } + if name == "/" { + break + } + name = name[:strings.LastIndex(name, "/")] + if name == "" { + name = "/" + } + } + return true +} + +type memLSNode struct { + // details are the lock metadata. Even if this node's name is not explicitly locked, + // details.Root will still equal the node's name. + details LockDetails + // token is the unique identifier for this node's lock. An empty token means that + // this node is not explicitly locked. + token string + // refCount is the number of self-or-descendent nodes that are explicitly locked. + refCount int + // expiry is when this node's lock expires. + expiry time.Time + // byExpiryIndex is the index of this node in memLS.byExpiry. It is -1 + // if this node does not expire, or has expired. + byExpiryIndex int + // held is whether this node's lock is actively held by a Confirm call. + held bool + // childLocks hold the relation between lock token and child locks. + // This is used to find out who is locking this file. + childLocks map[string]*memLSNode +} + +func (n *memLSNode) toConflictDetail(index int, hasher hashid.Encoder) *ConflictDetail { + return &ConflictDetail{ + Path: n.details.Root, + Owner: Owner{ + Application: n.details.Owner.Application, + }, + Token: n.token, + Index: index, + Type: n.details.Type, + } +} + +type byExpiry []*memLSNode + +func (b *byExpiry) Len() int { + return len(*b) +} + +func (b *byExpiry) Less(i, j int) bool { + return (*b)[i].expiry.Before((*b)[j].expiry) +} + +func (b *byExpiry) Swap(i, j int) { + (*b)[i], (*b)[j] = (*b)[j], (*b)[i] + (*b)[i].byExpiryIndex = i + (*b)[j].byExpiryIndex = j +} + +func (b *byExpiry) Push(x interface{}) { + n := x.(*memLSNode) + n.byExpiryIndex = len(*b) + *b = append(*b, n) +} + +func (b *byExpiry) Pop() interface{} { + i := len(*b) - 1 + n := (*b)[i] + (*b)[i] = nil + n.byExpiryIndex = -1 + *b = (*b)[:i] + return n +} + +// ConflictDetail represent lock conflicts that can be present to end users. 
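+// A minimal usage sketch (paths illustrative): a ConflictError built from two
+// details renders Error() as
+//
+//	conflict with locked resource: "/dir/a.txt","/dir/b.txt"
+//
+// and, since Unwrap returns ErrLocked, errors.Is(err, ErrLocked) reports true.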
+type ConflictDetail struct { + Path string `json:"path,omitempty"` + Token string `json:"token,omitempty"` + Owner Owner `json:"owner,omitempty"` + Index int `json:"-"` + Type types.FileType `json:"type"` +} + +type ConflictError []*ConflictDetail + +func (r ConflictError) Error() string { + return "conflict with locked resource: " + strings.Join( + lo.Map(r, func(item *ConflictDetail, index int) string { + return "\"" + item.Path + "\"" + }), ",") +} + +func (r ConflictError) Unwrap() error { + return ErrLocked +} diff --git a/pkg/filemanager/lock/memlock_test.go b/pkg/filemanager/lock/memlock_test.go new file mode 100644 index 00000000..37168833 --- /dev/null +++ b/pkg/filemanager/lock/memlock_test.go @@ -0,0 +1 @@ +package lock diff --git a/pkg/filemanager/manager/archive.go b/pkg/filemanager/manager/archive.go new file mode 100644 index 00000000..acd1efd8 --- /dev/null +++ b/pkg/filemanager/manager/archive.go @@ -0,0 +1,124 @@ +package manager + +import ( + "archive/zip" + "context" + "fmt" + "io" + "path" + "path/filepath" + "strings" + + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "golang.org/x/tools/container/intsets" +) + +func (m *manager) CreateArchive(ctx context.Context, uris []*fs.URI, writer io.Writer, opts ...fs.Option) (int, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + failed := 0 + + // List all top level files + files := make([]fs.File, 0, len(uris)) + for _, uri := range uris { + file, err := m.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return 0, fmt.Errorf("failed to get file %s: %w", uri, err) + } + + files = append(files, file) + } + + zipWriter := zip.NewWriter(writer) + defer zipWriter.Close() + + var compressed int64 + for _, file := range files { + if file.Type() == types.FileTypeFile { + if err := m.compressFileToArchive(ctx, "/", file, zipWriter, o.ArchiveCompression, o.DryRun); err != nil { + failed++ + m.l.Warning("Failed to compress file %s: %s, skipping it...", file.Uri(false), err) + } + + compressed += file.Size() + if o.ProgressFunc != nil { + o.ProgressFunc(compressed, file.Size(), 0) + } + + if o.MaxArchiveSize > 0 && compressed > o.MaxArchiveSize { + return 0, fs.ErrArchiveSrcSizeTooBig + } + + } else { + if err := m.Walk(ctx, file.Uri(false), intsets.MaxInt, func(f fs.File, level int) error { + if f.Type() == types.FileTypeFolder || f.IsSymbolic() { + return nil + } + if err := m.compressFileToArchive(ctx, strings.TrimPrefix(f.Uri(false).Dir(), + file.Uri(false).Dir()), f, zipWriter, o.ArchiveCompression, o.DryRun); err != nil { + failed++ + m.l.Warning("Failed to compress file %s: %s, skipping it...", f.Uri(false), err) + } + + compressed += f.Size() + if o.ProgressFunc != nil { + o.ProgressFunc(compressed, f.Size(), 0) + } + + if o.MaxArchiveSize > 0 && compressed > o.MaxArchiveSize { + return fs.ErrArchiveSrcSizeTooBig + } + + return nil + }); err != nil { + m.l.Warning("Failed to walk folder %s: %s, skipping it...", file.Uri(false), err) + failed++ + } + } + } + + return failed, nil +} + +func (m *manager) compressFileToArchive(ctx context.Context, parent string, file fs.File, zipWriter *zip.Writer, + compression bool, dryrun fs.CreateArchiveDryRunFunc) error { + es, err := m.GetEntitySource(ctx, file.PrimaryEntityID()) + if err 
!= nil { + return fmt.Errorf("failed to get entity source for file %s: %w", file.Uri(false), err) + } + + zipName := filepath.FromSlash(path.Join(parent, file.DisplayName())) + if dryrun != nil { + dryrun(zipName, es.Entity()) + return nil + } + + m.l.Debug("Compressing %s to archive...", file.Uri(false)) + header := &zip.FileHeader{ + Name: zipName, + Modified: file.UpdatedAt(), + UncompressedSize64: uint64(file.Size()), + } + + if !compression { + header.Method = zip.Store + } else { + header.Method = zip.Deflate + } + + writer, err := zipWriter.CreateHeader(header) + if err != nil { + return fmt.Errorf("failed to create zip header for %s: %w", file.Uri(false), err) + } + + es.Apply(entitysource.WithContext(ctx)) + _, err = io.Copy(writer, es) + return err + +} diff --git a/pkg/filemanager/manager/entity.go b/pkg/filemanager/manager/entity.go new file mode 100644 index 00000000..f3f9ee0d --- /dev/null +++ b/pkg/filemanager/manager/entity.go @@ -0,0 +1,365 @@ +package manager + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "fmt" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/samber/lo" +) + +type EntityManagement interface { + // GetEntityUrls gets download urls of given entities, return URLs and the earliest expiry time + GetEntityUrls(ctx context.Context, args []GetEntityUrlArgs, opts ...fs.Option) ([]string, *time.Time, error) + // GetUrlForRedirectedDirectLink gets redirected direct download link of given direct link + GetUrlForRedirectedDirectLink(ctx context.Context, dl *ent.DirectLink, opts ...fs.Option) (string, *time.Time, error) + // GetDirectLink gets permanent direct download link of given files + GetDirectLink(ctx context.Context, urls ...*fs.URI) ([]DirectLink, error) + // GetEntitySource gets source of given entity + GetEntitySource(ctx context.Context, entityID int, opts ...fs.Option) (entitysource.EntitySource, error) + // Thumbnail gets thumbnail entity of given file + Thumbnail(ctx context.Context, uri *fs.URI) (entitysource.EntitySource, error) + // SubmitAndAwaitThumbnailTask submits a thumbnail task and waits for result + SubmitAndAwaitThumbnailTask(ctx context.Context, uri *fs.URI, ext string, entity fs.Entity) (fs.Entity, error) + // SetCurrentVersion sets current version of given file + SetCurrentVersion(ctx context.Context, path *fs.URI, version int) error + // DeleteVersion deletes a version of given file + DeleteVersion(ctx context.Context, path *fs.URI, version int) error + // ExtractAndSaveMediaMeta extracts and saves media meta into file metadata of given file. 
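+ // Per the manager implementation, extraction is skipped unless the given
+ // entity is the latest version of the file, and each extracted item is
+ // stored as a metadata patch keyed "<type>:<key>".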
+ ExtractAndSaveMediaMeta(ctx context.Context, uri *fs.URI, entityID int) error + // RecycleEntities recycles a group of entities + RecycleEntities(ctx context.Context, force bool, entityIDs ...int) error +} + +type DirectLink struct { + File fs.File + Url string +} + +func (m *manager) GetDirectLink(ctx context.Context, urls ...*fs.URI) ([]DirectLink, error) { + ae := serializer.NewAggregateError() + res := make([]DirectLink, 0, len(urls)) + useRedirect := m.user.Edges.Group.Settings.RedirectedSource + fileClient := m.dep.FileClient() + siteUrl := m.settings.SiteURL(ctx) + + for _, url := range urls { + file, err := m.fs.Get( + ctx, url, + dbfs.WithFileEntities(), + dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), + ) + if err != nil { + ae.Add(url.String(), err) + continue + } + + if file.OwnerID() != m.user.ID { + ae.Add(url.String(), fs.ErrOwnerOnly) + continue + } + + if file.Type() != types.FileTypeFile { + ae.Add(url.String(), fs.ErrEntityNotExist) + continue + } + + target := file.PrimaryEntity() + if target == nil { + ae.Add(url.String(), fs.ErrEntityNotExist) + continue + } + + // Hooks for entity download + if err := m.fs.ExecuteNavigatorHooks(ctx, fs.HookTypeBeforeDownload, file); err != nil { + m.l.Warning("Failed to execute navigator hooks: %s", err) + } + + if useRedirect { + // Use redirect source + link, err := fileClient.CreateDirectLink(ctx, file.ID(), file.Name(), m.user.Edges.Group.SpeedLimit) + if err != nil { + ae.Add(url.String(), err) + continue + } + + linkHashID := hashid.EncodeSourceLinkID(m.hasher, link.ID) + res = append(res, DirectLink{ + File: file, + Url: routes.MasterDirectLink(siteUrl, linkHashID, link.Name).String(), + }) + } else { + // Use direct source + policy, d, err := m.getEntityPolicyDriver(ctx, target, nil) + if err != nil { + ae.Add(url.String(), err) + continue + } + + source := entitysource.NewEntitySource(target, d, policy, m.auth, m.settings, m.hasher, m.dep.RequestClient(), + m.l, m.config, m.dep.MimeDetector(ctx)) + sourceUrl, err := source.Url(ctx, + entitysource.WithSpeedLimit(int64(m.user.Edges.Group.SpeedLimit)), + entitysource.WithDisplayName(file.Name()), + ) + if err != nil { + ae.Add(url.String(), err) + continue + } + + res = append(res, DirectLink{ + File: file, + Url: sourceUrl.Url, + }) + } + + } + + return res, ae.Aggregate() +} + +func (m *manager) GetUrlForRedirectedDirectLink(ctx context.Context, dl *ent.DirectLink, opts ...fs.Option) (string, *time.Time, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + file, err := dl.Edges.FileOrErr() + if err != nil { + return "", nil, err + } + + owner, err := file.Edges.OwnerOrErr() + if err != nil { + return "", nil, err + } + + entities, err := file.Edges.EntitiesOrErr() + if err != nil { + return "", nil, err + } + + // File owner must be active + if owner.Status != user.StatusActive { + return "", nil, fs.ErrDirectLinkInvalid.WithError(fmt.Errorf("file owner is not active")) + } + + // Find primary entity + target, found := lo.Find(entities, func(entity *ent.Entity) bool { + return entity.ID == file.PrimaryEntity + }) + if !found { + return "", nil, fs.ErrDirectLinkInvalid.WithError(fmt.Errorf("primary entity not found")) + } + primaryEntity := fs.NewEntity(target) + + // Generate url + var ( + res string + expire *time.Time + ) + + // Try to read from cache. 
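+ // Direct links are permanent, but the signed upstream URL they redirect to
+ // is not; it is therefore memoized here, keyed by entity, dl.Speed and
+ // dl.Name, so edits to the link invalidate the cached URL.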
+ cacheKey := entityUrlCacheKey(primaryEntity.ID(), int64(dl.Speed), dl.Name, false, + m.settings.SiteURL(ctx).String()) + if cached, ok := m.kv.Get(cacheKey); ok { + cachedItem := cached.(EntityUrlCache) + res = cachedItem.Url + expire = cachedItem.ExpireAt + } else { + // Cache miss, Generate new url + policy, d, err := m.getEntityPolicyDriver(ctx, primaryEntity, nil) + if err != nil { + return "", nil, err + } + + source := entitysource.NewEntitySource(primaryEntity, d, policy, m.auth, m.settings, m.hasher, m.dep.RequestClient(), + m.l, m.config, m.dep.MimeDetector(ctx)) + downloadUrl, err := source.Url(ctx, + entitysource.WithExpire(o.Expire), + entitysource.WithDownload(false), + entitysource.WithSpeedLimit(int64(dl.Speed)), + entitysource.WithDisplayName(dl.Name), + ) + if err != nil { + return "", nil, err + } + + // Save into kv + cacheValidDuration := expireTimeToTTL(o.Expire) - m.settings.EntityUrlCacheMargin(ctx) + if cacheValidDuration > 0 { + m.kv.Set(cacheKey, EntityUrlCache{ + Url: downloadUrl.Url, + ExpireAt: downloadUrl.ExpireAt, + }, cacheValidDuration) + } + + res = downloadUrl.Url + expire = downloadUrl.ExpireAt + } + + return res, expire, nil +} + +func (m *manager) GetEntityUrls(ctx context.Context, args []GetEntityUrlArgs, opts ...fs.Option) ([]string, *time.Time, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + var earliestExpireAt *time.Time + res := make([]string, len(args)) + ae := serializer.NewAggregateError() + for i, arg := range args { + file, err := m.fs.Get( + ctx, arg.URI, + dbfs.WithFileEntities(), + dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), + ) + if err != nil { + ae.Add(arg.URI.String(), err) + continue + } + + if file.Type() != types.FileTypeFile { + ae.Add(arg.URI.String(), fs.ErrEntityNotExist) + continue + } + + var ( + target fs.Entity + found bool + ) + if arg.PreferredEntityID != "" { + found, target = fs.FindDesiredEntity(file, arg.PreferredEntityID, m.hasher, nil) + if !found { + ae.Add(arg.URI.String(), fs.ErrEntityNotExist) + continue + } + } else { + // No preferred entity ID, use the primary version entity + target = file.PrimaryEntity() + if target == nil { + ae.Add(arg.URI.String(), fs.ErrEntityNotExist) + continue + } + } + + // Hooks for entity download + if err := m.fs.ExecuteNavigatorHooks(ctx, fs.HookTypeBeforeDownload, file); err != nil { + m.l.Warning("Failed to execute navigator hooks: %s", err) + } + + // Try to read from cache. 
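+ // The cache key is a SHA-1 over entity ID, speed limit, display name, the
+ // download flag and the site URL (see entityUrlCacheKey), so changing any of
+ // them yields a fresh URL. Entries are kept for the URL's TTL minus the
+ // configured EntityUrlCacheMargin and are refilled below on a miss.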
+ cacheKey := entityUrlCacheKey(target.ID(), o.DownloadSpeed, getEntityDisplayName(file, target), o.IsDownload, + m.settings.SiteURL(ctx).String()) + if cached, ok := m.kv.Get(cacheKey); ok && !o.NoCache { + cachedItem := cached.(EntityUrlCache) + // Find the earliest expiry time + if cachedItem.ExpireAt != nil && (earliestExpireAt == nil || cachedItem.ExpireAt.Before(*earliestExpireAt)) { + earliestExpireAt = cachedItem.ExpireAt + } + res[i] = cachedItem.Url + continue + } + + // Cache miss, Generate new url + policy, d, err := m.getEntityPolicyDriver(ctx, target, nil) + if err != nil { + ae.Add(arg.URI.String(), err) + continue + } + + source := entitysource.NewEntitySource(target, d, policy, m.auth, m.settings, m.hasher, m.dep.RequestClient(), + m.l, m.config, m.dep.MimeDetector(ctx)) + downloadUrl, err := source.Url(ctx, + entitysource.WithExpire(o.Expire), + entitysource.WithDownload(o.IsDownload), + entitysource.WithSpeedLimit(o.DownloadSpeed), + entitysource.WithDisplayName(getEntityDisplayName(file, target)), + ) + if err != nil { + ae.Add(arg.URI.String(), err) + continue + } + + // Find the earliest expiry time + if downloadUrl.ExpireAt != nil && (earliestExpireAt == nil || downloadUrl.ExpireAt.Before(*earliestExpireAt)) { + earliestExpireAt = downloadUrl.ExpireAt + } + + // Save into kv + cacheValidDuration := expireTimeToTTL(o.Expire) - m.settings.EntityUrlCacheMargin(ctx) + if cacheValidDuration > 0 { + m.kv.Set(cacheKey, EntityUrlCache{ + Url: downloadUrl.Url, + ExpireAt: downloadUrl.ExpireAt, + }, cacheValidDuration) + } + + res[i] = downloadUrl.Url + } + + return res, earliestExpireAt, ae.Aggregate() +} + +func (m *manager) GetEntitySource(ctx context.Context, entityID int, opts ...fs.Option) (entitysource.EntitySource, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + var ( + entity fs.Entity + err error + ) + + if o.Entity != nil { + entity = o.Entity + } else { + entity, err = m.fs.GetEntity(ctx, entityID) + if err != nil { + return nil, err + } + + if entity.ReferenceCount() == 0 { + return nil, fs.ErrEntityNotExist + } + } + + policy, handler, err := m.getEntityPolicyDriver(ctx, entity, o.Policy) + if err != nil { + return nil, err + } + + return entitysource.NewEntitySource(entity, handler, policy, m.auth, m.settings, m.hasher, m.dep.RequestClient(), m.l, + m.config, m.dep.MimeDetector(ctx), entitysource.WithContext(ctx), entitysource.WithThumb(o.IsThumb)), nil +} + +func (l *manager) SetCurrentVersion(ctx context.Context, path *fs.URI, version int) error { + return l.fs.VersionControl(ctx, path, version, false) +} + +func (l *manager) DeleteVersion(ctx context.Context, path *fs.URI, version int) error { + return l.fs.VersionControl(ctx, path, version, true) +} + +func entityUrlCacheKey(id int, speed int64, displayName string, download bool, siteUrl string) string { + hash := sha1.New() + hash.Write([]byte(fmt.Sprintf("%d_%d_%s_%t_%s", id, + speed, displayName, download, siteUrl))) + hashRes := hex.EncodeToString(hash.Sum(nil)) + + return fmt.Sprintf("%s_%s", EntityUrlCacheKeyPrefix, hashRes) +} diff --git a/pkg/filemanager/manager/entitysource/entitysource.go b/pkg/filemanager/manager/entitysource/entitysource.go new file mode 100644 index 00000000..c38f9f91 --- /dev/null +++ b/pkg/filemanager/manager/entitysource/entitysource.go @@ -0,0 +1,958 @@ +package entitysource + +import ( + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httputil" + "net/textproto" + "net/url" + "path" + "strconv" + "strings" + 
"time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/local" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/juju/ratelimit" +) + +const ( + shortSeekBytes = 1024 + // The algorithm uses at most sniffLen bytes to make its decision. + sniffLen = 512 + defaultUrlExpire = time.Hour * 1 +) + +var ( + // ErrNoContentLength is returned by Seek when the initial http response did not include a Content-Length header + ErrNoContentLength = errors.New("Content-Length was not set") + + // errNoOverlap is returned by serveContent's parseRange if first-byte-pos of + // all of the byte-range-spec values is greater than the content size. + errNoOverlap = errors.New("invalid range: failed to overlap") +) + +type EntitySource interface { + io.ReadSeekCloser + io.ReaderAt + + // Url generates a download url for the entity. + Url(ctx context.Context, opts ...EntitySourceOption) (*EntityUrl, error) + // Serve serves the entity to the client, with supports on Range header and If- cache control. + Serve(w http.ResponseWriter, r *http.Request, opts ...EntitySourceOption) + // Entity returns the entity of the source. + Entity() fs.Entity + // IsLocal returns true if the source is in local machine. + IsLocal() bool + // LocalPath returns the local path of the source file. + LocalPath(ctx context.Context) string + // Apply applies the options to the source. + Apply(opts ...EntitySourceOption) + // CloneToLocalSrc clones the source to a local file source. + CloneToLocalSrc(t types.EntityType, src string) (EntitySource, error) + // ShouldInternalProxy returns true if the source will/should be proxied by internal proxy. 
+ ShouldInternalProxy(opts ...EntitySourceOption) bool +} + +type EntitySourceOption interface { + Apply(any) +} + +type EntitySourceOptions struct { + SpeedLimit int64 + Expire *time.Time + IsDownload bool + NoInternalProxy bool + DisplayName string + OneTimeDownloadKey string + Ctx context.Context + IsThumb bool +} + +type EntityUrl struct { + Url string + ExpireAt *time.Time +} + +type EntitySourceOptionFunc func(any) + +// WithSpeedLimit set speed limit for file source (if supported) +func WithSpeedLimit(limit int64) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).SpeedLimit = limit + }) +} + +// WithExpire set expire time for file source +func WithExpire(expire *time.Time) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).Expire = expire + }) +} + +// WithDownload set file URL as download +func WithDownload(isDownload bool) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).IsDownload = isDownload + }) +} + +// WithNoInternalProxy overwrite policy's internal proxy setting +func WithNoInternalProxy() EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).NoInternalProxy = true + }) +} + +// WithDisplayName set display name for file source +func WithDisplayName(name string) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).DisplayName = name + }) +} + +// WithContext set context for file source +func WithContext(ctx context.Context) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).Ctx = ctx + }) +} + +// WithThumb set entity source as thumb. This will result in entity source URL +// generated with thumbnail processing parameters. For sidecar thumb files, +// this option will be ignored. +func WithThumb(isThumb bool) EntitySourceOption { + return EntitySourceOptionFunc(func(option any) { + option.(*EntitySourceOptions).IsThumb = isThumb + }) +} + +func (f EntitySourceOptionFunc) Apply(option any) { + f(option) +} + +type ( + entitySource struct { + e fs.Entity + handler driver.Handler + policy *ent.StoragePolicy + generalAuth auth.Auth + settings setting.Provider + hasher hashid.Encoder + c request.Client + l logging.Logger + config conf.ConfigProvider + mime mime.MimeDetector + + rsc io.ReadCloser + pos int64 + o *EntitySourceOptions + } +) + +// NewEntitySource creates a new EntitySource. 
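+// Options are applied onto a shared EntitySourceOptions value and may be
+// layered later through Apply, Url or Serve. A minimal usage sketch (variable
+// names are illustrative; the manager normally performs this wiring):
+//
+//	src := NewEntitySource(entity, handler, policy, generalAuth, settings,
+//		hasher, client, logger, config, detector, WithContext(ctx))
+//	defer src.Close()
+//	u, err := src.Url(ctx, WithDownload(true), WithDisplayName(name))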
+func NewEntitySource( + e fs.Entity, + handler driver.Handler, + policy *ent.StoragePolicy, + generalAuth auth.Auth, + settings setting.Provider, + hasher hashid.Encoder, + c request.Client, + l logging.Logger, + config conf.ConfigProvider, + mime mime.MimeDetector, + opts ...EntitySourceOption, +) EntitySource { + s := &entitySource{ + e: e, + handler: handler, + policy: policy, + generalAuth: generalAuth, + settings: settings, + hasher: hasher, + c: c, + config: config, + l: l, + mime: mime, + o: &EntitySourceOptions{}, + } + for _, opt := range opts { + opt.Apply(s.o) + } + return s +} + +func (f *entitySource) Apply(opts ...EntitySourceOption) { + for _, opt := range opts { + opt.Apply(f.o) + } +} + +func (f *entitySource) CloneToLocalSrc(t types.EntityType, src string) (EntitySource, error) { + e, err := local.NewLocalFileEntity(t, src) + if err != nil { + return nil, err + } + + policy := &ent.StoragePolicy{Type: types.PolicyTypeLocal} + handler := local.New(policy, f.l, f.config) + + newSrc := NewEntitySource(e, handler, policy, f.generalAuth, f.settings, f.hasher, f.c, f.l, f.config, f.mime).(*entitySource) + newSrc.o = f.o + return newSrc, nil +} + +func (f *entitySource) Entity() fs.Entity { + return f.e +} + +func (f *entitySource) IsLocal() bool { + return f.handler.Capabilities().StaticFeatures.Enabled(int(driver.HandlerCapabilityInboundGet)) +} + +func (f *entitySource) LocalPath(ctx context.Context) string { + return f.handler.LocalPath(ctx, f.e.Source()) +} + +func (f *entitySource) Serve(w http.ResponseWriter, r *http.Request, opts ...EntitySourceOption) { + for _, opt := range opts { + opt.Apply(f.o) + } + + if f.IsLocal() { + // For local files, validate file existence by resetting rsc + if err := f.resetRequest(); err != nil { + f.l.Warning("Failed to serve local entity %q: %s", err, f.e.Source()) + http.Error(w, "Entity data does not exist.", http.StatusNotFound) + return + } + } + + etag := "\"" + hashid.EncodeEntityID(f.hasher, f.e.ID()) + "\"" + w.Header().Set("Etag", "\""+hashid.EncodeEntityID(f.hasher, f.e.ID())+"\"") + + if f.o.IsDownload { + encodedFilename := url.PathEscape(f.o.DisplayName) + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"; filename*=UTF-8''%s", + f.o.DisplayName, encodedFilename)) + } + + done, rangeReq := checkPreconditions(w, r, etag) + if done { + return + } + + if !f.IsLocal() { + // for non-local file, reverse-proxy the request + expire := time.Now().Add(defaultUrlExpire) + u, err := f.Url(driver.WithForcePublicEndpoint(f.o.Ctx, false), WithNoInternalProxy(), WithExpire(&expire)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + target, err := url.Parse(u.Url) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + start := time.Now() + proxy := &httputil.ReverseProxy{ + Director: func(request *http.Request) { + request.URL.Scheme = target.Scheme + request.URL.Host = target.Host + request.URL.Path = target.Path + request.URL.RawPath = target.RawPath + request.URL.RawQuery = target.RawQuery + request.Host = target.Host + request.Header.Del("Authorization") + }, + ModifyResponse: func(response *http.Response) error { + response.Header.Del("ETag") + response.Header.Del("Content-Disposition") + response.Header.Del("Cache-Control") + logging.Request(f.l, + false, + response.StatusCode, + response.Request.Method, + request.LocalIP, + response.Request.URL.String(), + "", + start, + ) + return nil + }, + ErrorHandler: func(writer 
http.ResponseWriter, request *http.Request, err error) { + f.l.Error("Reverse proxy error in %q: %s", request.URL.String(), err) + writer.WriteHeader(http.StatusBadGateway) + writer.Write([]byte("[Cloudreve] Bad Gateway")) + }, + } + + r = r.Clone(f.o.Ctx) + defer func() { + if err := recover(); err != nil && err != http.ErrAbortHandler { + panic(err) + } + }() + proxy.ServeHTTP(w, r) + return + } + + code := http.StatusOK + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = f.mime.TypeByName(f.o.DisplayName) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(f, buf[:]) + ctype = http.DetectContentType(buf[:n]) + _, err := f.Seek(0, io.SeekStart) // rewind to output whole file + if err != nil { + http.Error(w, "seeker can't seek", http.StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size := f.e.Size() + if size < 0 { + // Should never happen but just to be sure + http.Error(w, "negative content size computed", http.StatusInternalServerError) + return + } + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = f + ranges, err := parseRange(rangeReq, size) + switch err { + case nil: + case errNoOverlap: + if size == 0 { + // Some clients add a Range header to all requests to + // limit the size of the response. If the file is empty, + // ignore the range header and respond with a 200 rather + // than a 416. + ranges = nil + break + } + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + fallthrough + default: + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if _, err := f.Seek(ra.start, io.SeekStart); err != nil { + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = http.StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = http.StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. 
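+ // Each requested range becomes its own MIME part: the goroutine below seeks
+ // to the range start, copies exactly ra.length bytes into the part, and on
+ // any error closes the pipe with it so the CopyN in the response path stops.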
+ go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := f.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, f, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } +} + +func (f *entitySource) Read(p []byte) (n int, err error) { + if f.rsc == nil { + err = f.resetRequest() + } + if f.rsc != nil { + n, err = f.rsc.Read(p) + f.pos += int64(n) + } + return +} + +func (f *entitySource) ReadAt(p []byte, off int64) (n int, err error) { + if f.IsLocal() { + if f.rsc == nil { + err = f.resetRequest() + } + if readAt, ok := f.rsc.(io.ReaderAt); ok { + return readAt.ReadAt(p, off) + } + } + + return 0, errors.New("source does not support ReadAt") +} + +func (f *entitySource) Seek(offset int64, whence int) (int64, error) { + var err error + switch whence { + case io.SeekStart: + case io.SeekCurrent: + offset += f.pos + case io.SeekEnd: + offset = f.e.Size() + offset + } + if f.rsc != nil { + // Try to read, which is cheaper than doing a request + if f.pos < offset && offset-f.pos <= shortSeekBytes { + _, err := io.CopyN(io.Discard, f, offset-f.pos) + if err != nil { + return 0, err + } + } + + if f.pos != offset { + err = f.rsc.Close() + f.rsc = nil + } + } + f.pos = offset + return f.pos, err +} + +func (f *entitySource) Close() error { + if f.rsc != nil { + return f.rsc.Close() + } + return nil +} + +func (f *entitySource) ShouldInternalProxy(opts ...EntitySourceOption) bool { + for _, opt := range opts { + opt.Apply(f.o) + } + handlerCapability := f.handler.Capabilities() + return f.e.ID() == 0 || handlerCapability.StaticFeatures.Enabled(int(driver.HandlerCapabilityProxyRequired)) || + f.policy.Settings.InternalProxy && !f.o.NoInternalProxy +} + +func (f *entitySource) Url(ctx context.Context, opts ...EntitySourceOption) (*EntityUrl, error) { + for _, opt := range opts { + opt.Apply(f.o) + } + + var ( + srcUrl *url.URL + err error + srcUrlStr string + ) + + expire := f.o.Expire + displayName := f.o.DisplayName + if displayName == "" { + displayName = path.Base(util.FormSlash(f.e.Source())) + } + + // Use internal proxy URL if: + // 1. Internal proxy is required by driver's definition + // 2. Internal proxy is enabled in Policy setting and not disabled by option + // 3. It's an empty entity. 
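+ // Otherwise the storage driver issues the URL itself (Source, or Thumb for
+ // thumbnails), with the requested expiry clamped into the handler's
+ // Min/MaxSourceExpire window by capExpireTime and the policy-level proxy
+ // applied afterwards when configured.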
+ handlerCapability := f.handler.Capabilities() + if f.ShouldInternalProxy() { + siteUrl := f.settings.SiteURL(ctx) + base := routes.MasterFileContentUrl( + siteUrl, + hashid.EncodeEntityID(f.hasher, f.e.ID()), + displayName, + f.o.IsDownload, + f.o.IsThumb, + f.o.SpeedLimit, + ) + + srcUrl, err = auth.SignURI(ctx, f.generalAuth, base.String(), expire) + if err != nil { + return nil, fmt.Errorf("failed to sign internal proxy URL: %w", err) + } + + if f.IsLocal() { + // For local file, we need to apply proxy if needed + srcUrl, err = driver.ApplyProxyIfNeeded(f.policy, srcUrl) + if err != nil { + return nil, fmt.Errorf("failed to apply proxy: %w", err) + } + } + } else { + expire = capExpireTime(expire, handlerCapability.MinSourceExpire, handlerCapability.MaxSourceExpire) + if f.o.IsThumb { + srcUrlStr, err = f.handler.Thumb(ctx, expire, util.Ext(f.o.DisplayName), f.e) + } else { + srcUrlStr, err = f.handler.Source(ctx, f.e, &driver.GetSourceArgs{ + Expire: expire, + IsDownload: f.o.IsDownload, + Speed: f.o.SpeedLimit, + DisplayName: displayName, + }) + } + if err != nil { + return nil, fmt.Errorf("failed to get source URL: %w", err) + } + + srcUrl, err = url.Parse(srcUrlStr) + if err != nil { + return nil, fmt.Errorf("failed to parse origin URL: %w", err) + } + + srcUrl, err = driver.ApplyProxyIfNeeded(f.policy, srcUrl) + if err != nil { + return nil, fmt.Errorf("failed to apply proxy: %w", err) + } + } + + return &EntityUrl{ + Url: srcUrl.String(), + ExpireAt: expire, + }, nil +} + +func (f *entitySource) resetRequest() error { + // For inbound files, we can use the handler to open the file directly + if f.IsLocal() { + if f.rsc == nil { + file, err := f.handler.Open(f.o.Ctx, f.e.Source()) + if err != nil { + return fmt.Errorf("failed to open inbound file: %w", err) + } + + if f.pos > 0 { + _, err = file.Seek(f.pos, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek inbound file: %w", err) + } + } + + f.rsc = file + + if f.o.SpeedLimit > 0 { + bucket := ratelimit.NewBucketWithRate(float64(f.o.SpeedLimit), f.o.SpeedLimit) + f.rsc = lrs{f.rsc, ratelimit.Reader(f.rsc, bucket)} + } + } + + return nil + } + + expire := time.Now().Add(defaultUrlExpire) + u, err := f.Url(driver.WithForcePublicEndpoint(f.o.Ctx, false), WithNoInternalProxy(), WithExpire(&expire)) + if err != nil { + return fmt.Errorf("failed to generate download url: %w", err) + } + + h := http.Header{} + h.Set("Range", fmt.Sprintf("bytes=%d-", f.pos)) + resp := f.c.Request(http.MethodGet, u.Url, nil, + request.WithContext(f.o.Ctx), + request.WithLogger(f.l), + request.WithHeader(h), + ).CheckHTTPResponse(http.StatusOK, http.StatusPartialContent) + if resp.Err != nil { + return fmt.Errorf("failed to request download url: %w", resp.Err) + } + + f.rsc = resp.Response.Body + return nil +} + +// capExpireTime make sure expire time is not too long or too short (if min or max is set) +func capExpireTime(expire *time.Time, min, max time.Duration) *time.Time { + timeNow := time.Now() + if expire == nil { + return nil + } + + cappedExpires := *expire + // Make sure expire time is not too long or too short + if min > 0 && expire.Before(timeNow.Add(min)) { + cappedExpires = timeNow.Add(min) + } else if max > 0 && expire.After(timeNow.Add(max)) { + cappedExpires = timeNow.Add(max) + } + + return &cappedExpires +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. 
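+// Evaluation order follows the RFC: If-Match is checked first (a failure sends
+// 412), then If-None-Match (a match sends 304 for GET/HEAD and 412 otherwise);
+// finally the Range header is dropped when If-Range does not match the ETag.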
+func checkPreconditions(w http.ResponseWriter, r *http.Request, etag string) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(r, etag) + if ch == condFalse { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(r, etag) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + } + + rangeHeader = r.Header.Get("Range") + if rangeHeader != "" && checkIfRange(r, etag) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(r *http.Request, currentEtag string) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, currentEtag) { + return condTrue + } + im = remain + } + + return condFalse +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +func checkIfNoneMatch(r *http.Request, currentEtag string) condResult { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, currentEtag) { + return condFalse + } + buf = remain + } + return condTrue +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). 
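+ // Accordingly, the entity headers are removed below and Last-Modified is
+ // dropped whenever an ETag is present.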
+ h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} + +func checkIfRange(r *http.Request, currentEtag string) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, currentEtag) { + return condTrue + } else { + return condFalse + } + } + + return condFalse +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. 
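+// It replays the part headers into a countingWriter (discarding bodies), so the
+// result is the header/boundary overhead plus the sum of the range lengths;
+// Serve advertises this value as Content-Length for multipart/byteranges
+// responses.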
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +type lrs struct { + c io.Closer + r io.Reader +} + +func (r lrs) Read(p []byte) (int, error) { + return r.r.Read(p) +} + +func (r lrs) Close() error { + return r.c.Close() +} diff --git a/pkg/filemanager/manager/fs.go b/pkg/filemanager/manager/fs.go new file mode 100644 index 00000000..b0e4791b --- /dev/null +++ b/pkg/filemanager/manager/fs.go @@ -0,0 +1,114 @@ +package manager + +import ( + "context" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/cos" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/local" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/obs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/onedrive" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/qiniu" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/remote" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/s3" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/upyun" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" +) + +func (m *manager) LocalDriver(policy *ent.StoragePolicy) driver.Handler { + if policy == nil { + policy = &ent.StoragePolicy{Type: types.PolicyTypeLocal, Settings: &types.PolicySetting{}} + } + return local.New(policy, m.l, m.config) +} + +func (m *manager) CastStoragePolicyOnSlave(ctx context.Context, policy *ent.StoragePolicy) *ent.StoragePolicy { + if !m.stateless { + return policy + } + + nodeId := cluster.NodeIdFromContext(ctx) + if policy.Type == types.PolicyTypeRemote { + if nodeId != policy.NodeID { + return policy + } + + policyCopy := *policy + policyCopy.Type = types.PolicyTypeLocal + return &policyCopy + } else if policy.Type == types.PolicyTypeLocal { + policyCopy := *policy + policyCopy.NodeID = nodeId + policyCopy.Type = types.PolicyTypeRemote + policyCopy.SetNode(&ent.Node{ + ID: nodeId, + Server: cluster.MasterSiteUrlFromContext(ctx), + SlaveKey: m.config.Slave().Secret, + }) + return &policyCopy + } else if policy.Type == types.PolicyTypeOss { + policyCopy := *policy + if policyCopy.Settings != nil { + policyCopy.Settings.ServerSideEndpoint = "" + } + } + + return policy +} + +func (m *manager) GetStorageDriver(ctx context.Context, policy *ent.StoragePolicy) (driver.Handler, error) { + switch policy.Type { + case types.PolicyTypeLocal: + return local.New(policy, m.l, m.config), nil + case types.PolicyTypeRemote: + return remote.New(ctx, policy, m.settings, m.config, m.l) + case types.PolicyTypeOss: + return oss.New(ctx, policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeCos: + return cos.New(ctx, policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeS3: + return s3.New(ctx, policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeObs: + return obs.New(ctx, policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeQiniu: + return qiniu.New(ctx, 
policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeUpyun: + return upyun.New(ctx, policy, m.settings, m.config, m.l, m.dep.MimeDetector(ctx)) + case types.PolicyTypeOd: + return onedrive.New(ctx, policy, m.settings, m.config, m.l, m.dep.CredManager()) + default: + return nil, ErrUnknownPolicyType + } +} + +func (m *manager) getEntityPolicyDriver(cxt context.Context, e fs.Entity, policyOverwrite *ent.StoragePolicy) (*ent.StoragePolicy, driver.Handler, error) { + policyID := e.PolicyID() + var ( + policy *ent.StoragePolicy + err error + ) + if policyID == 0 { + policy = &ent.StoragePolicy{Type: types.PolicyTypeLocal, Settings: &types.PolicySetting{}} + } else { + if policyOverwrite != nil && policyOverwrite.ID == policyID { + policy = policyOverwrite + } else { + policy, err = m.policyClient.GetPolicyByID(cxt, e.PolicyID()) + if err != nil { + return nil, nil, serializer.NewError(serializer.CodeDBError, "failed to get policy", err) + } + } + } + + d, err := m.GetStorageDriver(cxt, policy) + if err != nil { + return nil, nil, err + } + + return policy, d, nil +} diff --git a/pkg/filemanager/manager/manager.go b/pkg/filemanager/manager/manager.go new file mode 100644 index 00000000..19065f36 --- /dev/null +++ b/pkg/filemanager/manager/manager.go @@ -0,0 +1,171 @@ +package manager + +import ( + "context" + "io" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var ( + ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil) +) + +const ( + UploadSessionCachePrefix = "callback_" + // Ctx key for upload session + UploadSessionCtx = "uploadSession" +) + +type ( + FileOperation interface { + // Get gets file object by given path + Get(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, error) + // List lists files under given path + List(ctx context.Context, path *fs.URI, args *ListArgs) (fs.File, *fs.ListFileResult, error) + // Create creates a file or directory + Create(ctx context.Context, path *fs.URI, fileType types.FileType, opt ...fs.Option) (fs.File, error) + // Rename renames a file or directory + Rename(ctx context.Context, path *fs.URI, newName string) (fs.File, error) + // Delete deletes a group of file or directory. UnlinkOnly indicates whether to delete file record in DB only. + Delete(ctx context.Context, path []*fs.URI, opts ...fs.Option) error + // Restore restores a group of files + Restore(ctx context.Context, path ...*fs.URI) error + // MoveOrCopy moves or copies a group of files + MoveOrCopy(ctx context.Context, src []*fs.URI, dst *fs.URI, isCopy bool) error + // Update puts file content. If given file does not exist, it will create a new one. 
+ Update(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (fs.File, error) + // Walk walks through given path + Walk(ctx context.Context, path *fs.URI, depth int, f fs.WalkFunc, opts ...fs.Option) error + // UpsertMedata update or insert metadata of given file + PatchMedata(ctx context.Context, path []*fs.URI, data ...fs.MetadataPatch) error + // CreateViewerSession creates a viewer session for given file + CreateViewerSession(ctx context.Context, uri *fs.URI, version string, viewer *setting.Viewer) (*ViewerSession, error) + } + + FsManagement interface { + // SharedAddressTranslation translates shared symbolic address to real address. If path does not exist, + // most recent existing parent directory will be returned. + SharedAddressTranslation(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.URI, error) + // Capacity gets capacity of current file system + Capacity(ctx context.Context) (*fs.Capacity, error) + // CheckIfCapacityExceeded checks if given user's capacity exceeded, and send notification email + CheckIfCapacityExceeded(ctx context.Context) error + // LocalDriver gets local driver for operating local files. + LocalDriver(policy *ent.StoragePolicy) driver.Handler + // CastStoragePolicyOnSlave check if given storage policy need to be casted to another. + // It is used on slave node, when local policy need to cast to remote policy; + // Remote policy with same node ID can be casted to local policy. + CastStoragePolicyOnSlave(ctx context.Context, policy *ent.StoragePolicy) *ent.StoragePolicy + // GetStorageDriver gets storage driver for given policy + GetStorageDriver(ctx context.Context, policy *ent.StoragePolicy) (driver.Handler, error) + } + + ShareManagement interface { + // CreateShare creates a share link for given path + CreateOrUpdateShare(ctx context.Context, path *fs.URI, args *CreateShareArgs) (*ent.Share, error) + } + + Archiver interface { + CreateArchive(ctx context.Context, uris []*fs.URI, writer io.Writer, opts ...fs.Option) (int, error) + } + + FileManager interface { + fs.LockSystem + FileOperation + EntityManagement + UploadManagement + FsManagement + ShareManagement + Archiver + + // Recycle reset current FileManager object and put back to resource pool + Recycle() + } + + // GetEntityUrlArgs single args to get entity url + GetEntityUrlArgs struct { + URI *fs.URI + PreferredEntityID string + } + + // CreateShareArgs args to create share link + CreateShareArgs struct { + ExistedShareID int + IsPrivate bool + RemainDownloads int + Expire *time.Time + } +) + +type manager struct { + user *ent.User + l logging.Logger + fs fs.FileSystem + settings setting.Provider + kv cache.Driver + config conf.ConfigProvider + stateless bool + auth auth.Auth + hasher hashid.Encoder + policyClient inventory.StoragePolicyClient + + dep dependency.Dep +} + +func NewFileManager(dep dependency.Dep, u *ent.User) FileManager { + config := dep.ConfigProvider() + if config.System().Mode == conf.SlaveMode || u == nil { + return newStatelessFileManager(dep) + } + return &manager{ + l: dep.Logger(), + user: u, + settings: dep.SettingProvider(), + fs: dbfs.NewDatabaseFS(u, dep.FileClient(), dep.ShareClient(), dep.Logger(), dep.LockSystem(), + dep.SettingProvider(), dep.StoragePolicyClient(), dep.HashIDEncoder(), dep.UserClient(), dep.KV(), dep.NavigatorStateKV()), + kv: dep.KV(), + config: config, + auth: dep.GeneralAuth(), + hasher: dep.HashIDEncoder(), + policyClient: dep.StoragePolicyClient(), + dep: dep, + } +} + +func newStatelessFileManager(dep 
dependency.Dep) FileManager { + return &manager{ + l: dep.Logger(), + settings: dep.SettingProvider(), + kv: dep.KV(), + config: dep.ConfigProvider(), + stateless: true, + auth: dep.GeneralAuth(), + dep: dep, + hasher: dep.HashIDEncoder(), + } +} + +func (m *manager) Recycle() { + if m.fs != nil { + m.fs.Recycle() + } +} + +func newOption() *fs.FsOption { + return &fs.FsOption{} +} diff --git a/pkg/filemanager/manager/mediameta.go b/pkg/filemanager/manager/mediameta.go new file mode 100644 index 00000000..4ed3d252 --- /dev/null +++ b/pkg/filemanager/manager/mediameta.go @@ -0,0 +1,193 @@ +package manager + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/samber/lo" +) + +type ( + MediaMetaTask struct { + *queue.DBTask + } + + MediaMetaTaskState struct { + Uri *fs.URI `json:"uri"` + EntityID int `json:"entity_id"` + } +) + +func init() { + queue.RegisterResumableTaskFactory(queue.MediaMetaTaskType, NewMediaMetaTaskFromModel) +} + +// NewMediaMetaTask creates a new MediaMetaTask to +func NewMediaMetaTask(ctx context.Context, uri *fs.URI, entityID int, creator *ent.User) (*MediaMetaTask, error) { + state := &MediaMetaTaskState{ + Uri: uri, + EntityID: entityID, + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + return &MediaMetaTask{ + DBTask: &queue.DBTask{ + DirectOwner: creator, + Task: &ent.Task{ + Type: queue.MediaMetaTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{}, + }, + }, + }, nil +} + +func NewMediaMetaTaskFromModel(task *ent.Task) queue.Task { + return &MediaMetaTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func (m *MediaMetaTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + fm := NewFileManager(dep, inventory.UserFromContext(ctx)).(*manager) + + // unmarshal state + var state MediaMetaTaskState + if err := json.Unmarshal([]byte(m.State()), &state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %s (%w)", err, queue.CriticalErr) + } + + err := fm.ExtractAndSaveMediaMeta(ctx, state.Uri, state.EntityID) + if err != nil { + return task.StatusError, err + } + + return task.StatusCompleted, nil +} + +func (m *manager) ExtractAndSaveMediaMeta(ctx context.Context, uri *fs.URI, entityID int) error { + // 1. 
retrieve file info
+ file, err := m.fs.Get(ctx, uri, dbfs.WithFileEntities())
+ if err != nil {
+ return fmt.Errorf("failed to get file: %w", err)
+ }
+
+ versions := lo.Filter(file.Entities(), func(i fs.Entity, index int) bool {
+ return i.Type() == types.EntityTypeVersion
+ })
+ targetVersion, versionIndex, found := lo.FindIndexOf(versions, func(i fs.Entity) bool {
+ return i.ID() == entityID
+ })
+ if !found {
+ return fmt.Errorf("failed to find version %d (%w)", entityID, queue.CriticalErr)
+ }
+
+ if versionIndex != 0 {
+ m.l.Debug("Skip media meta task for non-latest version.")
+ return nil
+ }
+
+ var (
+ metas []driver.MediaMeta
+ )
+ // 2. try using native driver
+ _, d, err := m.getEntityPolicyDriver(ctx, targetVersion, nil)
+ if err != nil {
+ return fmt.Errorf("failed to get storage driver: %s (%w)", err, queue.CriticalErr)
+ }
+ driverCaps := d.Capabilities()
+ if util.IsInExtensionList(driverCaps.MediaMetaSupportedExts, file.Name()) {
+ m.l.Debug("Using native driver to generate media meta.")
+ metas, err = d.MediaMeta(ctx, targetVersion.Source(), file.Ext())
+ if err != nil {
+ return fmt.Errorf("failed to get media meta using native driver: %w", err)
+ }
+ } else if driverCaps.MediaMetaProxy && util.IsInExtensionList(m.dep.MediaMetaExtractor(ctx).Exts(), file.Name()) {
+ m.l.Debug("Using local extractor to generate media meta.")
+ extractor := m.dep.MediaMetaExtractor(ctx)
+ source, err := m.GetEntitySource(ctx, targetVersion.ID())
+ if err != nil {
+ return fmt.Errorf("failed to get entity source: %w", err)
+ }
+ defer source.Close()
+
+ metas, err = extractor.Extract(ctx, file.Ext(), source)
+ if err != nil {
+ return fmt.Errorf("failed to extract media meta using local extractor: %w", err)
+ }
+
+ } else {
+ m.l.Debug("No available generator for media meta.")
+ return nil
+ }
+
+ m.l.Debug("%d media meta generated.", len(metas))
+ m.l.Debug("Media meta: %v", metas)
+
+ // 3. save meta
+ if len(metas) > 0 {
+ if err := m.fs.PatchMetadata(ctx, []*fs.URI{uri}, lo.Map(metas, func(i driver.MediaMeta, index int) fs.MetadataPatch {
+ return fs.MetadataPatch{
+ Key: fmt.Sprintf("%s:%s", i.Type, i.Key),
+ Value: i.Value,
+ }
+ })...); err != nil {
+ return fmt.Errorf("failed to save media meta: %s (%w)", err, queue.CriticalErr)
+ }
+ }
+
+ return nil
+}
+
+func (m *manager) shouldGenerateMediaMeta(ctx context.Context, d driver.Handler, fileName string) bool {
+ driverCaps := d.Capabilities()
+ if util.IsInExtensionList(driverCaps.MediaMetaSupportedExts, fileName) {
+ // Handler supports it natively
+ return true
+ }
+
+ if driverCaps.MediaMetaProxy && util.IsInExtensionList(m.dep.MediaMetaExtractor(ctx).Exts(), fileName) {
+ // Handler does not support it natively, but proxy is enabled. 
+ return true + } + + return false +} + +func (m *manager) mediaMetaForNewEntity(ctx context.Context, session *fs.UploadSession, d driver.Handler) { + if session.Props.EntityType == nil || *session.Props.EntityType == types.EntityTypeVersion { + if !m.shouldGenerateMediaMeta(ctx, d, session.Props.Uri.Name()) { + return + } + + mediaMetaTask, err := NewMediaMetaTask(ctx, session.Props.Uri, session.EntityID, m.user) + if err != nil { + m.l.Warning("Failed to create media meta task: %s", err) + return + } + if err := m.dep.MediaMetaQueue(ctx).QueueTask(ctx, mediaMetaTask); err != nil { + m.l.Warning("Failed to queue media meta task: %s", err) + } + + return + } +} diff --git a/pkg/filemanager/manager/metadata.go b/pkg/filemanager/manager/metadata.go new file mode 100644 index 00000000..b6ae4600 --- /dev/null +++ b/pkg/filemanager/manager/metadata.go @@ -0,0 +1,174 @@ +package manager + +import ( + "context" + "crypto/sha1" + "encoding/json" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/go-playground/validator/v10" + "strings" +) + +type ( + metadataValidator func(ctx context.Context, m *manager, patch *fs.MetadataPatch) error +) + +const ( + wildcardMetadataKey = "*" + customizeMetadataSuffix = "customize" + tagMetadataSuffix = "tag" + iconColorMetadataKey = customizeMetadataSuffix + ":icon_color" + emojiIconMetadataKey = customizeMetadataSuffix + ":emoji" + shareOwnerMetadataKey = dbfs.MetadataSysPrefix + "shared_owner" + shareRedirectMetadataKey = dbfs.MetadataSysPrefix + "shared_redirect" +) + +var ( + validate = validator.New() + + lastEmojiHash = "" + emojiPresets = map[string]struct{}{} + + // validateColor validates a color value + validateColor = func(optional bool) metadataValidator { + return func(ctx context.Context, m *manager, patch *fs.MetadataPatch) error { + if patch.Remove { + return nil + } + + tag := "omitempty,iscolor" + if !optional { + tag = "required,iscolor" + } + + res := validate.Var(patch.Value, tag) + if res != nil { + return fmt.Errorf("invalid color: %w", res) + } + + return nil + } + } + validators = map[string]map[string]metadataValidator{ + "sys": { + wildcardMetadataKey: func(ctx context.Context, m *manager, patch *fs.MetadataPatch) error { + if patch.Remove { + return fmt.Errorf("cannot remove system metadata") + } + + dep := dependency.FromContext(ctx) + // Validate share owner is valid hashid + if patch.Key == shareOwnerMetadataKey { + hasher := dep.HashIDEncoder() + _, err := hasher.Decode(patch.Value, hashid.UserID) + if err != nil { + return fmt.Errorf("invalid share owner: %w", err) + } + + return nil + } + + // Validate share redirect uri is valid share uri + if patch.Key == shareRedirectMetadataKey { + uri, err := fs.NewUriFromString(patch.Value) + if err != nil || uri.FileSystem() != constants.FileSystemShare { + return fmt.Errorf("invalid redirect uri: %w", err) + } + + return nil + } + + return fmt.Errorf("unsupported system metadata key: %s", patch.Key) + }, + }, + "dav": {}, + customizeMetadataSuffix: { + iconColorMetadataKey: validateColor(false), + emojiIconMetadataKey: func(ctx context.Context, m *manager, patch *fs.MetadataPatch) error { + if patch.Remove { + return nil + } + + // Validate if patched emoji is within preset list. 
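+ // The preset list is parsed from the JSON setting value; a SHA-1 digest of the
+ // raw value is compared against the last seen digest so the lookup map can be
+ // rebuilt when the admin changes the presets.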
+ emojis := m.settings.EmojiPresets(ctx) + current := fmt.Sprintf("%x", (sha1.Sum([]byte(emojis)))) + if current != lastEmojiHash { + presets := make(map[string][]string) + if err := json.Unmarshal([]byte(emojis), &presets); err != nil { + return fmt.Errorf("failed to read emoji setting: %w", err) + } + + emojiPresets = make(map[string]struct{}) + for _, v := range presets { + for _, emoji := range v { + emojiPresets[emoji] = struct{}{} + } + } + } + + if _, ok := emojiPresets[patch.Value]; !ok { + return fmt.Errorf("unsupported emoji") + } + return nil + }, + }, + tagMetadataSuffix: { + wildcardMetadataKey: func(ctx context.Context, m *manager, patch *fs.MetadataPatch) error { + if err := validateColor(true)(ctx, m, patch); err != nil { + return err + } + + if patch.Key == tagMetadataSuffix+":" { + return fmt.Errorf("invalid metadata key") + } + + return nil + }, + }, + } +) + +func (m *manager) PatchMedata(ctx context.Context, path []*fs.URI, data ...fs.MetadataPatch) error { + if err := m.validateMetadata(ctx, data...); err != nil { + return err + } + + return m.fs.PatchMetadata(ctx, path, data...) +} + +func (m *manager) validateMetadata(ctx context.Context, data ...fs.MetadataPatch) error { + for _, patch := range data { + category := strings.Split(patch.Key, ":") + if len(category) < 2 { + return serializer.NewError(serializer.CodeParamErr, "Invalid metadata key", nil) + } + + categoryValidators, ok := validators[category[0]] + if !ok { + return serializer.NewError(serializer.CodeParamErr, "Invalid metadata key", + fmt.Errorf("unknown category: %s", category[0])) + } + + // Explicit validators + if v, ok := categoryValidators[patch.Key]; ok { + if err := v(ctx, m, &patch); err != nil { + return serializer.NewError(serializer.CodeParamErr, "Invalid metadata patch", err) + } + } + + // Wildcard validators + if v, ok := categoryValidators[wildcardMetadataKey]; ok { + if err := v(ctx, m, &patch); err != nil { + return serializer.NewError(serializer.CodeParamErr, "Invalid metadata patch", err) + } + } + } + + return nil +} diff --git a/pkg/filemanager/manager/operation.go b/pkg/filemanager/manager/operation.go new file mode 100644 index 00000000..ffadfdcf --- /dev/null +++ b/pkg/filemanager/manager/operation.go @@ -0,0 +1,296 @@ +package manager + +import ( + "context" + "encoding/gob" + "fmt" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/samber/lo" +) + +const ( + EntityUrlCacheKeyPrefix = "entity_url_" + DownloadSentinelCachePrefix = "download_sentinel_" +) + +type ( + ListArgs struct { + Page int + PageSize int + PageToken string + Order string + OrderDirection string + // StreamResponseCallback is used for streamed list operation, e.g. searching files. + // Whenever a new item is found, this callback will be called with the current item and the parent item. 
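+ // It is only wired up for search listings when UseSSEForSearch is enabled (see List).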
+ StreamResponseCallback func(fs.File, []fs.File) + } + + EntityUrlCache struct { + Url string + ExpireAt *time.Time + } +) + +func init() { + gob.Register(EntityUrlCache{}) +} + +func (m *manager) Get(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, error) { + return m.fs.Get(ctx, path, opts...) +} + +func (m *manager) List(ctx context.Context, path *fs.URI, args *ListArgs) (fs.File, *fs.ListFileResult, error) { + dbfsSetting := m.settings.DBFS(ctx) + opts := []fs.Option{ + fs.WithPageSize(args.PageSize), + fs.WithOrderBy(args.Order), + fs.WithOrderDirection(args.OrderDirection), + dbfs.WithFilePublicMetadata(), + dbfs.WithContextHint(), + dbfs.WithFileShareIfOwned(), + } + + searchParams := path.SearchParameters() + if searchParams != nil { + if dbfsSetting.UseSSEForSearch { + opts = append(opts, dbfs.WithStreamListResponseCallback(args.StreamResponseCallback)) + } + + if searchParams.Category != "" { + // Overwrite search query with predefined category + category := fs.SearchCategoryFromString(searchParams.Category) + if category == setting.CategoryUnknown { + return nil, nil, fmt.Errorf("unknown category: %s", searchParams.Category) + } + + path = path.SetQuery(m.settings.SearchCategoryQuery(ctx, category)) + searchParams = path.SearchParameters() + } + } + + if dbfsSetting.UseCursorPagination || searchParams != nil { + opts = append(opts, dbfs.WithCursorPagination(args.PageToken)) + } else { + opts = append(opts, fs.WithPage(args.Page)) + } + + return m.fs.List(ctx, path, opts...) +} + +func (m *manager) SharedAddressTranslation(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.URI, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + return m.fs.SharedAddressTranslation(ctx, path) +} + +func (m *manager) Create(ctx context.Context, path *fs.URI, fileType types.FileType, opts ...fs.Option) (fs.File, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + if m.stateless { + return nil, o.Node.CreateFile(ctx, &fs.StatelessCreateFileService{ + Path: path.String(), + Type: fileType, + UserID: o.StatelessUserID, + }) + } + + isSymbolic := false + if o.Metadata != nil { + if err := m.validateMetadata(ctx, lo.MapToSlice(o.Metadata, func(key string, value string) fs.MetadataPatch { + if key == shareRedirectMetadataKey { + isSymbolic = true + } + + return fs.MetadataPatch{ + Key: key, + Value: value, + } + })...); err != nil { + return nil, err + } + } + + if isSymbolic { + opts = append(opts, dbfs.WithSymbolicLink()) + } + + return m.fs.Create(ctx, path, fileType, opts...) +} + +func (m *manager) Rename(ctx context.Context, path *fs.URI, newName string) (fs.File, error) { + return m.fs.Rename(ctx, path, newName) +} + +func (m *manager) MoveOrCopy(ctx context.Context, src []*fs.URI, dst *fs.URI, isCopy bool) error { + return m.fs.MoveOrCopy(ctx, src, dst, isCopy) +} + +func (m *manager) SoftDelete(ctx context.Context, path ...*fs.URI) error { + return m.fs.SoftDelete(ctx, path...) +} + +func (m *manager) Delete(ctx context.Context, path []*fs.URI, opts ...fs.Option) error { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + + if !o.SkipSoftDelete && !o.SysSkipSoftDelete { + return m.SoftDelete(ctx, path...) 
+ } + + staleEntities, err := m.fs.Delete(ctx, path, fs.WithUnlinkOnly(o.UnlinkOnly), fs.WithSysSkipSoftDelete(o.SysSkipSoftDelete)) + if err != nil { + return err + } + + m.l.Debug("New stale entities: %v", staleEntities) + + // Delete stale entities + if len(staleEntities) > 0 { + t, err := newExplicitEntityRecycleTask(ctx, lo.Map(staleEntities, func(entity fs.Entity, index int) int { + return entity.ID() + })) + if err != nil { + return fmt.Errorf("failed to create explicit entity recycle task: %w", err) + } + + if err := m.dep.EntityRecycleQueue(ctx).QueueTask(ctx, t); err != nil { + return fmt.Errorf("failed to queue explicit entity recycle task: %w", err) + } + } + return nil +} + +func (m *manager) Walk(ctx context.Context, path *fs.URI, depth int, f fs.WalkFunc, opts ...fs.Option) error { + return m.fs.Walk(ctx, path, depth, f, opts...) +} + +func (m *manager) Capacity(ctx context.Context) (*fs.Capacity, error) { + res, err := m.fs.Capacity(ctx, m.user) + if err != nil { + return nil, err + } + + return res, nil +} + +func (m *manager) CheckIfCapacityExceeded(ctx context.Context) error { + ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true) + capacity, err := m.Capacity(ctx) + if err != nil { + return fmt.Errorf("failed to get user capacity: %w", err) + } + + if capacity.Used <= capacity.Total { + return nil + } + + return nil +} + +func (l *manager) ConfirmLock(ctx context.Context, ancestor fs.File, uri *fs.URI, token ...string) (func(), fs.LockSession, error) { + return l.fs.ConfirmLock(ctx, ancestor, uri, token...) +} + +func (l *manager) Lock(ctx context.Context, d time.Duration, requester *ent.User, zeroDepth bool, + application lock.Application, uri *fs.URI, token string) (fs.LockSession, error) { + return l.fs.Lock(ctx, d, requester, zeroDepth, application, uri, token) +} + +func (l *manager) Unlock(ctx context.Context, tokens ...string) error { + return l.fs.Unlock(ctx, tokens...) +} + +func (l *manager) Refresh(ctx context.Context, d time.Duration, token string) (lock.LockDetails, error) { + return l.fs.Refresh(ctx, d, token) +} + +func (l *manager) Restore(ctx context.Context, path ...*fs.URI) error { + return l.fs.Restore(ctx, path...) 
+} + +func (l *manager) CreateOrUpdateShare(ctx context.Context, path *fs.URI, args *CreateShareArgs) (*ent.Share, error) { + file, err := l.fs.Get(ctx, path, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityShare)) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "src file not found", err) + } + + // Only file owner can share file + if file.OwnerID() != l.user.ID { + return nil, serializer.NewError(serializer.CodeNoPermissionErr, "permission denied", nil) + } + + if file.IsSymbolic() { + return nil, serializer.NewError(serializer.CodeNoPermissionErr, "cannot share symbolic file", nil) + } + + var existed *ent.Share + shareClient := l.dep.ShareClient() + if args.ExistedShareID != 0 { + loadShareCtx := context.WithValue(ctx, inventory.LoadShareFile{}, true) + existed, err = shareClient.GetByID(loadShareCtx, args.ExistedShareID) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "failed to get existed share", err) + } + + if existed.Edges.File.ID != file.ID() { + return nil, serializer.NewError(serializer.CodeNotFound, "share link not found", nil) + } + } + + password := "" + if args.IsPrivate { + password = util.RandString(8, util.RandomLowerCases) + } + + share, err := shareClient.Upsert(ctx, &inventory.CreateShareParams{ + OwnerID: file.OwnerID(), + FileID: file.ID(), + Password: password, + Expires: args.Expire, + RemainDownloads: args.RemainDownloads, + Existed: existed, + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "failed to create share", err) + } + + return share, nil +} + +func getEntityDisplayName(f fs.File, e fs.Entity) string { + switch e.Type() { + case types.EntityTypeThumbnail: + return fmt.Sprintf("%s_thumbnail", f.DisplayName()) + case types.EntityTypeLivePhoto: + return fmt.Sprintf("%s_live_photo.mov", f.DisplayName()) + default: + return f.Name() + } +} + +func expireTimeToTTL(expireAt *time.Time) int { + if expireAt == nil { + return -1 + } + + return int(time.Until(*expireAt).Seconds()) +} diff --git a/pkg/filemanager/manager/recycle.go b/pkg/filemanager/manager/recycle.go new file mode 100644 index 00000000..6213e350 --- /dev/null +++ b/pkg/filemanager/manager/recycle.go @@ -0,0 +1,374 @@ +package manager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/crontab" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" +) + +type ( + ExplicitEntityRecycleTask struct { + *queue.DBTask + } + + ExplicitEntityRecycleTaskState struct { + EntityIDs []int `json:"entity_ids,omitempty"` + Errors [][]RecycleError `json:"errors,omitempty"` + } + + RecycleError struct { + ID string `json:"id"` + Error string `json:"error"` + } +) + +func init() { + queue.RegisterResumableTaskFactory(queue.ExplicitEntityRecycleTaskType, NewExplicitEntityRecycleTaskFromModel) + queue.RegisterResumableTaskFactory(queue.EntityRecycleRoutineTaskType, NewEntityRecycleRoutineTaskFromModel) + 
crontab.Register(setting.CronTypeEntityCollect, func(ctx context.Context) { + dep := dependency.FromContext(ctx) + l := dep.Logger() + t, err := NewEntityRecycleRoutineTask(ctx) + if err != nil { + l.Error("Failed to create entity recycle routine task: %s", err) + } + + if err := dep.EntityRecycleQueue(ctx).QueueTask(ctx, t); err != nil { + l.Error("Failed to queue entity recycle routine task: %s", err) + } + }) + crontab.Register(setting.CronTypeTrashBinCollect, CronCollectTrashBin) +} + +func NewExplicitEntityRecycleTaskFromModel(task *ent.Task) queue.Task { + return &ExplicitEntityRecycleTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func newExplicitEntityRecycleTask(ctx context.Context, entities []int) (*ExplicitEntityRecycleTask, error) { + state := &ExplicitEntityRecycleTaskState{ + EntityIDs: entities, + Errors: make([][]RecycleError, 0), + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + t := &ExplicitEntityRecycleTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.ExplicitEntityRecycleTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{ + ResumeTime: time.Now().Unix() - 1, + }, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func (m *ExplicitEntityRecycleTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + fm := NewFileManager(dep, inventory.UserFromContext(ctx)).(*manager) + + // unmarshal state + state := &ExplicitEntityRecycleTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + + // recycle entities + err := fm.RecycleEntities(ctx, false, state.EntityIDs...) 
+ if err != nil { + appendAe(&state.Errors, err) + privateState, err := json.Marshal(state) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", err) + } + m.Task.PrivateState = string(privateState) + return task.StatusError, err + } + + return task.StatusCompleted, nil +} + +type ( + EntityRecycleRoutineTask struct { + *queue.DBTask + } + + EntityRecycleRoutineTaskState struct { + Errors [][]RecycleError `json:"errors,omitempty"` + } +) + +func NewEntityRecycleRoutineTaskFromModel(task *ent.Task) queue.Task { + return &EntityRecycleRoutineTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func NewEntityRecycleRoutineTask(ctx context.Context) (queue.Task, error) { + state := &EntityRecycleRoutineTaskState{ + Errors: make([][]RecycleError, 0), + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + t := &EntityRecycleRoutineTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.EntityRecycleRoutineTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{ + ResumeTime: time.Now().Unix() - 1, + }, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func (m *EntityRecycleRoutineTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + fm := NewFileManager(dep, inventory.UserFromContext(ctx)).(*manager) + + // unmarshal state + state := &EntityRecycleRoutineTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + + // recycle entities + err := fm.RecycleEntities(ctx, false) + if err != nil { + appendAe(&state.Errors, err) + + privateState, err := json.Marshal(state) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", err) + } + m.Task.PrivateState = string(privateState) + return task.StatusError, err + } + + return task.StatusCompleted, nil +} + +// RecycleEntities delete given entities. If the ID list is empty, it will walk through +// all stale entities in DB. +func (m *manager) RecycleEntities(ctx context.Context, force bool, entityIDs ...int) error { + ae := serializer.NewAggregateError() + entities, err := m.fs.StaleEntities(ctx, entityIDs...) + if err != nil { + return fmt.Errorf("failed to get entities: %w", err) + } + + // Group entities by policy ID + entityGroup := lo.GroupBy(entities, func(entity fs.Entity) int { + return entity.PolicyID() + }) + + // Delete entity in each group in batch + for _, entities := range entityGroup { + entityChunk := lo.Chunk(entities, 100) + m.l.Info("Recycling %d entities in %d batches", len(entities), len(entityChunk)) + + for batch, chunk := range entityChunk { + m.l.Info("Start to recycle batch #%d, %d entities", batch, len(chunk)) + mapSrcToId := make(map[string]int, len(chunk)) + _, d, err := m.getEntityPolicyDriver(ctx, chunk[0], nil) + if err != nil { + for _, entity := range chunk { + ae.Add(strconv.Itoa(entity.ID()), err) + } + continue + } + + for _, entity := range chunk { + mapSrcToId[entity.Source()] = entity.ID() + } + + res, err := d.Delete(ctx, lo.Map(chunk, func(entity fs.Entity, index int) string { + return entity.Source() + })...) 
+ if err != nil { + for _, src := range res { + ae.Add(strconv.Itoa(mapSrcToId[src]), err) + } + } + + // Delete upload session if it's still valid + for _, entity := range chunk { + sid := entity.UploadSessionID() + if sid == nil { + continue + } + + if session, ok := m.kv.Get(UploadSessionCachePrefix + sid.String()); ok { + session := session.(fs.UploadSession) + if err := d.CancelToken(ctx, &session); err != nil { + m.l.Warning("Failed to cancel upload session for %q: %s, this is expected if it's remote policy.", session.Props.Uri.String(), err) + } + _ = m.kv.Delete(UploadSessionCachePrefix, sid.String()) + } + } + + // Filtering out entities that are successfully deleted + rawAe := ae.Raw() + successEntities := lo.FilterMap(chunk, func(entity fs.Entity, index int) (int, bool) { + entityIdStr := fmt.Sprintf("%d", entity.ID()) + _, ok := rawAe[entityIdStr] + if !ok { + // No error, deleted + return entity.ID(), true + } + + if force { + ae.Remove(entityIdStr) + } + return entity.ID(), force + }) + + // Remove entities from DB + fc, tx, ctx, err := inventory.WithTx(ctx, m.dep.FileClient()) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + storageReduced, err := fc.RemoveEntitiesByID(ctx, successEntities...) + if err != nil { + _ = inventory.Rollback(tx) + return fmt.Errorf("failed to remove entities from DB: %w", err) + } + + tx.AppendStorageDiff(storageReduced) + if err := inventory.CommitWithStorageDiff(ctx, tx, m.l, m.dep.UserClient()); err != nil { + return fmt.Errorf("failed to commit delete change: %w", err) + } + + } + } + + return ae.Aggregate() +} + +const ( + MinimumTrashCollectBatch = 1000 +) + +// CronCollectTrashBin walks through all files in trash bin and delete them if they are expired. +func CronCollectTrashBin(ctx context.Context) { + dep := dependency.FromContext(ctx) + l := dep.Logger() + fm := NewFileManager(dep, inventory.UserFromContext(ctx)).(*manager) + pageSize := dep.SettingProvider().DBFS(ctx).MaxPageSize + batch := 0 + expiredFiles := make([]fs.File, 0) + for { + res, err := fm.fs.AllFilesInTrashBin(ctx, fs.WithPageSize(pageSize)) + if err != nil { + l.Error("Failed to get files in trash bin: %s", err) + } + + expired := lo.Filter(res.Files, func(file fs.File, index int) bool { + if expire, ok := file.Metadata()[dbfs.MetadataExpectedCollectTime]; ok { + expireUnix, err := strconv.ParseInt(expire, 10, 64) + if err != nil { + l.Warning("Failed to parse expected collect time %q: %s, will treat as expired", expire, err) + } + + if expireUnix < time.Now().Unix() { + return true + } + } + + return false + }) + l.Info("Found %d files in trash bin pending collect, in batch #%d", len(res.Files), batch) + + expiredFiles = append(expiredFiles, expired...) 
+ if len(expiredFiles) >= MinimumTrashCollectBatch { + collectTrashBin(ctx, expiredFiles, dep, l) + expiredFiles = expiredFiles[:0] + } + + if res.Pagination.NextPageToken == "" { + if len(expiredFiles) > 0 { + collectTrashBin(ctx, expiredFiles, dep, l) + } + break + } + + batch++ + } +} + +func collectTrashBin(ctx context.Context, files []fs.File, dep dependency.Dep, l logging.Logger) { + l.Info("Start to collect %d files in trash bin", len(files)) + uc := dep.UserClient() + + // Group files by Owners + fileGroup := lo.GroupBy(files, func(file fs.File) int { + return file.OwnerID() + }) + + for uid, expiredFiles := range fileGroup { + ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true) + user, err := uc.GetByID(ctx, uid) + if err != nil { + l.Error("Failed to get user %d: %s", uid, err) + continue + } + + ctx = context.WithValue(ctx, inventory.UserCtx{}, user) + fm := NewFileManager(dep, user).(*manager) + if err := fm.Delete(ctx, lo.Map(expiredFiles, func(file fs.File, index int) *fs.URI { + return file.Uri(false) + }), fs.WithSkipSoftDelete(true)); err != nil { + l.Error("Failed to delete files for user %d: %s", uid, err) + } + } +} + +func appendAe(errs *[][]RecycleError, err error) { + var ae *serializer.AggregateError + *errs = append(*errs, make([]RecycleError, 0)) + if errors.As(err, &ae) { + (*errs)[len(*errs)-1] = lo.MapToSlice(ae.Raw(), func(key string, value error) RecycleError { + return RecycleError{ + ID: key, + Error: value.Error(), + } + }) + } +} diff --git a/pkg/filemanager/manager/thumbnail.go b/pkg/filemanager/manager/thumbnail.go new file mode 100644 index 00000000..83ca98ce --- /dev/null +++ b/pkg/filemanager/manager/thumbnail.go @@ -0,0 +1,294 @@ +package manager + +import ( + "context" + "errors" + "fmt" + "os" + "runtime" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/local" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/samber/lo" +) + +// Thumbnail returns the thumbnail entity of the file. +func (m *manager) Thumbnail(ctx context.Context, uri *fs.URI) (entitysource.EntitySource, error) { + // retrieve file info + file, err := m.fs.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithFilePublicMetadata()) + if err != nil { + return nil, fmt.Errorf("failed to get file: %w", err) + } + + // 0. Check if thumb is disabled in this file. + if _, ok := file.Metadata()[dbfs.ThumbDisabledKey]; ok || file.Type() != types.FileTypeFile { + return nil, fs.ErrEntityNotExist + } + + // 1. If thumbnail entity exist, use it. 
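+ // Otherwise fall back to: (2) the storage driver's native thumbnail support,
+ // (3) the local proxy generator via a thumbnail task, and if neither is
+ // available, mark the file as having no thumbnail.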
+ entities := file.Entities() + thumbEntity, found := lo.Find(entities, func(e fs.Entity) bool { + return e.Type() == types.EntityTypeThumbnail + }) + if found { + thumbSource, err := m.GetEntitySource(ctx, 0, fs.WithEntity(thumbEntity)) + if err != nil { + return nil, fmt.Errorf("failed to get entity source: %w", err) + } + + thumbSource.Apply(entitysource.WithDisplayName(file.DisplayName() + ".jpg")) + return thumbSource, nil + } + + latest := file.PrimaryEntity() + // If primary entity not exist, or it's empty + if latest == nil || latest.ID() == 0 { + return nil, fmt.Errorf("failed to get latest version") + } + + // 2. Thumb entity not exist, try native policy generator + _, handler, err := m.getEntityPolicyDriver(ctx, latest, nil) + if err != nil { + return nil, fmt.Errorf("failed to get entity policy driver: %w", err) + } + capabilities := handler.Capabilities() + // Check if file extension and size is supported by native policy generator. + if capabilities.ThumbSupportAllExts || util.IsInExtensionList(capabilities.ThumbSupportedExts, file.DisplayName()) && + (capabilities.ThumbMaxSize == 0 || latest.Size() <= capabilities.ThumbMaxSize) { + thumbSource, err := m.GetEntitySource(ctx, 0, fs.WithEntity(latest), fs.WithUseThumb(true)) + if err != nil { + return nil, fmt.Errorf("failed to get latest entity source: %w", err) + } + + thumbSource.Apply(entitysource.WithDisplayName(file.DisplayName())) + return thumbSource, nil + } else if capabilities.ThumbProxy { + if err := m.fs.CheckCapability(ctx, uri, + dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityGenerateThumb)); err != nil { + // Current FS does not support generate new thumb. + return nil, fs.ErrEntityNotExist + } + + thumbEntity, err := m.SubmitAndAwaitThumbnailTask(ctx, uri, file.Ext(), latest) + if err != nil { + return nil, fmt.Errorf("failed to execute thumb task: %w", err) + } + + thumbSource, err := m.GetEntitySource(ctx, 0, fs.WithEntity(thumbEntity)) + if err != nil { + return nil, fmt.Errorf("failed to get entity source: %w", err) + } + + return thumbSource, nil + } else { + // 4. If proxy generator not support, mark thumb as not available. 
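+ // Disabling is recorded as a metadata flag so later requests can short-circuit
+ // in step 0 instead of retrying generation.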
+ _ = disableThumb(ctx, m, uri) + } + + return nil, fs.ErrEntityNotExist +} + +func (m *manager) SubmitAndAwaitThumbnailTask(ctx context.Context, uri *fs.URI, ext string, entity fs.Entity) (fs.Entity, error) { + es, err := m.GetEntitySource(ctx, 0, fs.WithEntity(entity)) + if err != nil { + return nil, fmt.Errorf("failed to get entity source: %w", err) + } + + defer es.Close() + t := newGenerateThumbTask(ctx, m, uri, ext, es) + if err := m.dep.ThumbQueue(ctx).QueueTask(ctx, t); err != nil { + return nil, fmt.Errorf("failed to queue task: %w", err) + } + + // Wait for task to finish + select { + case <-ctx.Done(): + return nil, ctx.Err() + case res := <-t.sig: + if res.err != nil { + return nil, fmt.Errorf("failed to generate thumb: %w", res.err) + } + + return res.thumbEntity, nil + } + +} + +func (m *manager) generateThumb(ctx context.Context, uri *fs.URI, ext string, es entitysource.EntitySource) (fs.Entity, error) { + // Generate thumb + pipeline := m.dep.ThumbPipeline() + res, err := pipeline.Generate(ctx, es, ext, nil) + if err != nil { + if res != nil && res.Path != "" { + _ = os.Remove(res.Path) + } + + if !errors.Is(err, context.Canceled) && !m.stateless { + if err := disableThumb(ctx, m, uri); err != nil { + m.l.Warning("Failed to disable thumb: %v", err) + } + } + + return nil, fmt.Errorf("failed to generate thumb: %w", err) + } + + defer os.Remove(res.Path) + + // Upload thumb entity + thumbFile, err := os.Open(res.Path) + if err != nil { + return nil, fmt.Errorf("failed to open temp thumb %q: %w", res.Path, err) + } + + defer thumbFile.Close() + fileInfo, err := thumbFile.Stat() + if err != nil { + return nil, fmt.Errorf("failed to stat temp thumb %q: %w", res.Path, err) + } + + var ( + thumbEntity fs.Entity + ) + if m.stateless { + _, d, err := m.getEntityPolicyDriver(ctx, es.Entity(), nil) + if err != nil { + return nil, fmt.Errorf("failed to get storage driver: %w", err) + } + + savePath := es.Entity().Source() + m.settings.ThumbSlaveSidecarSuffix(ctx) + if err := d.Put(ctx, &fs.UploadRequest{ + File: thumbFile, + Seeker: thumbFile, + Props: &fs.UploadProps{SavePath: savePath}, + }); err != nil { + return nil, fmt.Errorf("failed to save thumb sidecar: %w", err) + } + + thumbEntity, err = local.NewLocalFileEntity(types.EntityTypeThumbnail, savePath) + if err != nil { + return nil, fmt.Errorf("failed to create local thumb entity: %w", err) + } + } else { + entityType := types.EntityTypeThumbnail + req := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: uri, + Size: fileInfo.Size(), + SavePath: fmt.Sprintf( + "%s.%s%s", + es.Entity().Source(), + util.RandStringRunes(16), + m.settings.ThumbEntitySuffix(ctx), + ), + MimeType: m.dep.MimeDetector(ctx).TypeByName("thumb.jpg"), + EntityType: &entityType, + }, + File: thumbFile, + Seeker: thumbFile, + } + + // Generating thumb can be triggered by users with read-only permission. We can bypass update permission check. 
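+ // The generated thumbnail is uploaded as a separate thumbnail-type entity, so
+ // bypassing the owner check here does not let the caller modify file content.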
+ ctx = dbfs.WithBypassOwnerCheck(ctx) + + file, err := m.Update(ctx, req, fs.WithEntityType(types.EntityTypeThumbnail)) + if err != nil { + return nil, fmt.Errorf("failed to upload thumb entity: %w", err) + } + + entities := file.Entities() + found := false + thumbEntity, found = lo.Find(entities, func(e fs.Entity) bool { + return e.Type() == types.EntityTypeThumbnail + }) + if !found { + return nil, fmt.Errorf("failed to find thumb entity") + } + + } + + if m.settings.ThumbGCAfterGen(ctx) { + m.l.Debug("GC after thumb generation") + runtime.GC() + } + + return thumbEntity, nil +} + +type ( + GenerateThumbTask struct { + *queue.InMemoryTask + es entitysource.EntitySource + ext string + m *manager + uri *fs.URI + sig chan *generateRes + } + generateRes struct { + thumbEntity fs.Entity + err error + } +) + +func newGenerateThumbTask(ctx context.Context, m *manager, uri *fs.URI, ext string, es entitysource.EntitySource) *GenerateThumbTask { + t := &GenerateThumbTask{ + InMemoryTask: &queue.InMemoryTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + CorrelationID: logging.CorrelationID(ctx), + PublicState: &types.TaskPublicState{}, + }, + }, + }, + es: es, + ext: ext, + m: m, + uri: uri, + sig: make(chan *generateRes, 2), + } + + t.InMemoryTask.DBTask.Task.SetUser(m.user) + return t +} + +func (m *GenerateThumbTask) Do(ctx context.Context) (task.Status, error) { + var ( + res fs.Entity + err error + ) + defer func() { m.sig <- &generateRes{res, err} }() + + // Make sure user does not cancel request before we start generating thumb. + select { + case <-ctx.Done(): + err = ctx.Err() + return task.StatusError, err + default: + } + + res, err = m.m.generateThumb(ctx, m.uri, m.ext, m.es) + return task.StatusCompleted, nil +} + +func (m *GenerateThumbTask) OnError(err error, d time.Duration) { + m.InMemoryTask.OnError(err, d) + m.sig <- &generateRes{nil, err} +} + +func disableThumb(ctx context.Context, m *manager, uri *fs.URI) error { + return m.fs.PatchMetadata( + dbfs.WithBypassOwnerCheck(ctx), + []*fs.URI{uri}, fs.MetadataPatch{ + Key: dbfs.ThumbDisabledKey, + Value: "", + Private: false, + }) +} diff --git a/pkg/filemanager/manager/upload.go b/pkg/filemanager/manager/upload.go new file mode 100644 index 00000000..6a4a6955 --- /dev/null +++ b/pkg/filemanager/manager/upload.go @@ -0,0 +1,500 @@ +package manager + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +type ( + UploadManagement interface { + // CreateUploadSession creates a upload session for given upload request + CreateUploadSession(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (*fs.UploadCredential, error) + // ConfirmUploadSession confirms whether upload session is valid for upload. 
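+ // It loads the placeholder file, confirms locks on it, ensures the storage policy
+ // accepts data relayed through this server, and checks that the chunk offset is
+ // within the declared file size.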
+ ConfirmUploadSession(ctx context.Context, session *fs.UploadSession, chunkIndex int) (fs.File, error)
+ // Upload uploads file data to storage
+ Upload(ctx context.Context, req *fs.UploadRequest, policy *ent.StoragePolicy) error
+ // CompleteUpload completes upload session and returns file object
+ CompleteUpload(ctx context.Context, session *fs.UploadSession) (fs.File, error)
+ // CancelUploadSession cancels upload session
+ CancelUploadSession(ctx context.Context, path *fs.URI, sessionID string) error
+ // OnUploadFailed should be called when an unmanaged upload fails before completion.
+ OnUploadFailed(ctx context.Context, session *fs.UploadSession)
+ // PrepareUpload is similar to CreateUploadSession, but does not create an actual upload session in storage.
+ PrepareUpload(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (*fs.UploadSession, error)
+ }
+)
+
+func (m *manager) CreateUploadSession(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (*fs.UploadCredential, error) {
+ o := newOption()
+ for _, opt := range opts {
+ opt.Apply(o)
+ }
+
+ // Validate metadata
+ if req.Props.Metadata != nil {
+ if err := m.validateMetadata(ctx, lo.MapToSlice(req.Props.Metadata, func(key string, value string) fs.MetadataPatch {
+ return fs.MetadataPatch{
+ Key: key,
+ Value: value,
+ }
+ })...); err != nil {
+ return nil, err
+ }
+ }
+
+ uploadSession := o.UploadSession
+ var (
+ err error
+ )
+
+ if uploadSession == nil {
+ // If upload session not specified, invoke DBFS to create one
+ sessionID := uuid.Must(uuid.NewV4()).String()
+ req.Props.UploadSessionID = sessionID
+ ttl := m.settings.UploadSessionTTL(ctx)
+ req.Props.ExpireAt = time.Now().Add(ttl)
+
+ // Prepare for upload
+ uploadSession, err = m.fs.PrepareUpload(ctx, req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to prepare upload: %w", err)
+ }
+ }
+
+ d, err := m.GetStorageDriver(ctx, m.CastStoragePolicyOnSlave(ctx, uploadSession.Policy))
+ if err != nil {
+ m.OnUploadFailed(ctx, uploadSession)
+ return nil, err
+ }
+
+ uploadSession.ChunkSize = uploadSession.Policy.Settings.ChunkSize
+ // Create upload credential for underlying storage driver
+ credential := &fs.UploadCredential{}
+ if !uploadSession.Policy.Settings.Relay || m.stateless {
+ credential, err = d.Token(ctx, uploadSession, req)
+ if err != nil {
+ m.OnUploadFailed(ctx, uploadSession)
+ return nil, err
+ }
+ } else {
+ // For relayed upload, we don't need to create a credential
+ uploadSession.ChunkSize = 0
+ credential.ChunkSize = 0
+ }
+ credential.SessionID = uploadSession.Props.UploadSessionID
+ credential.Expires = req.Props.ExpireAt.Unix()
+ credential.StoragePolicy = uploadSession.Policy
+ credential.CallbackSecret = uploadSession.CallbackSecret
+ credential.Uri = uploadSession.Props.Uri.String()
+
+ // If upload sentinel check is required, queue a check task
+ if d.Capabilities().StaticFeatures.Enabled(int(driver.HandlerCapabilityUploadSentinelRequired)) {
+ t, err := newUploadSentinelCheckTask(ctx, uploadSession)
+ if err != nil {
+ m.OnUploadFailed(ctx, uploadSession)
+ return nil, fmt.Errorf("failed to create upload sentinel check task: %w", err)
+ }
+
+ if err := m.dep.EntityRecycleQueue(ctx).QueueTask(ctx, t); err != nil {
+ m.OnUploadFailed(ctx, uploadSession)
+ return nil, fmt.Errorf("failed to queue upload sentinel check task: %w", err)
+ }
+
+ uploadSession.SentinelTaskID = t.ID()
+ }
+
+ err = m.kv.Set(
+ UploadSessionCachePrefix+req.Props.UploadSessionID,
+ *uploadSession,
+ max(1, int(req.Props.ExpireAt.Sub(time.Now()).Seconds())),
+ )
+ if err 
!= nil { + m.OnUploadFailed(ctx, uploadSession) + return nil, err + } + + return credential, nil +} + +func (m *manager) ConfirmUploadSession(ctx context.Context, session *fs.UploadSession, chunkIndex int) (fs.File, error) { + // Get placeholder file + file, err := m.fs.Get(ctx, session.Props.Uri) + if err != nil { + return nil, fmt.Errorf("failed to get placeholder file: %w", err) + } + + // Confirm locks on placeholder file + if session.LockToken == "" { + release, ls, err := m.fs.ConfirmLock(ctx, file, file.Uri(false), session.LockToken) + if err != nil { + return nil, fs.ErrLockExpired.WithError(err) + } + + defer release() + ctx = fs.LockSessionToContext(ctx, ls) + } + + // Make sure this storage policy is OK to receive data from clients to Cloudreve server. + if session.Policy.Type != types.PolicyTypeLocal && !session.Policy.Settings.Relay { + return nil, serializer.NewError(serializer.CodePolicyNotAllowed, "", nil) + } + + actualSizeStart := int64(chunkIndex) * session.ChunkSize + if session.Policy.Settings.ChunkSize == 0 && chunkIndex > 0 { + return nil, serializer.NewError(serializer.CodeInvalidChunkIndex, "Chunk index cannot be greater than 0", nil) + } + + if actualSizeStart > 0 && actualSizeStart >= session.Props.Size { + return nil, serializer.NewError(serializer.CodeInvalidChunkIndex, "Chunk offset cannot be greater than file size", nil) + } + + return file, nil +} + +func (m *manager) PrepareUpload(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (*fs.UploadSession, error) { + return m.fs.PrepareUpload(ctx, req, opts...) +} + +func (m *manager) Upload(ctx context.Context, req *fs.UploadRequest, policy *ent.StoragePolicy) error { + d, err := m.GetStorageDriver(ctx, m.CastStoragePolicyOnSlave(ctx, policy)) + if err != nil { + return err + } + + if err := d.Put(ctx, req); err != nil { + return serializer.NewError(serializer.CodeIOFailed, "Failed to upload file", err) + } + + return nil +} + +func (m *manager) CancelUploadSession(ctx context.Context, path *fs.URI, sessionID string) error { + // Get upload session + var session *fs.UploadSession + sessionRaw, ok := m.kv.Get(UploadSessionCachePrefix + sessionID) + if ok { + sessionTyped := sessionRaw.(fs.UploadSession) + session = &sessionTyped + } + + var ( + staleEntities []fs.Entity + err error + ) + + if !m.stateless { + staleEntities, err = m.fs.CancelUploadSession(ctx, path, sessionID, session) + if err != nil { + return err + } + + m.l.Debug("New stale entities: %v", staleEntities) + } + + if session != nil { + ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, strconv.Itoa(session.Policy.NodeID)) + d, err := m.GetStorageDriver(ctx, m.CastStoragePolicyOnSlave(ctx, session.Policy)) + if err != nil { + return fmt.Errorf("failed to get storage driver: %w", err) + } + + if m.stateless { + if _, err = d.Delete(ctx, session.Props.SavePath); err != nil { + return fmt.Errorf("failed to delete file: %w", err) + } + } else { + if err = d.CancelToken(ctx, session); err != nil { + return fmt.Errorf("failed to cancel upload session: %w", err) + } + } + + m.kv.Delete(UploadSessionCachePrefix, session.Props.UploadSessionID) + } + + // Delete stale entities + if len(staleEntities) > 0 { + t, err := newExplicitEntityRecycleTask(ctx, lo.Map(staleEntities, func(entity fs.Entity, index int) int { + return entity.ID() + })) + if err != nil { + return fmt.Errorf("failed to create explicit entity recycle task: %w", err) + } + + if err := m.dep.EntityRecycleQueue(ctx).QueueTask(ctx, t); err != nil { + return fmt.Errorf("failed to 
queue explicit entity recycle task: %w", err) + } + } + + return nil +} + +func (m *manager) CompleteUpload(ctx context.Context, session *fs.UploadSession) (fs.File, error) { + d, err := m.GetStorageDriver(ctx, m.CastStoragePolicyOnSlave(ctx, session.Policy)) + if err != nil { + return nil, err + } + + if err := d.CompleteUpload(ctx, session); err != nil { + return nil, err + } + + var ( + file fs.File + ) + if m.fs != nil { + file, err = m.fs.CompleteUpload(ctx, session) + if err != nil { + return nil, fmt.Errorf("failed to complete upload: %w", err) + } + } + + if session.SentinelTaskID > 0 { + // Cancel sentinel check task + m.l.Debug("Cancel upload sentinel check task [%d].", session.SentinelTaskID) + if err := m.dep.TaskClient().SetCompleteByID(ctx, session.SentinelTaskID); err != nil { + m.l.Warning("Failed to set upload sentinel check task [%d] to complete: %s", session.SentinelTaskID, err) + } + } + + m.onNewEntityUploaded(ctx, session, d) + // Remove upload session + _ = m.kv.Delete(UploadSessionCachePrefix, session.Props.UploadSessionID) + return file, nil +} + +func (m *manager) Update(ctx context.Context, req *fs.UploadRequest, opts ...fs.Option) (fs.File, error) { + o := newOption() + for _, opt := range opts { + opt.Apply(o) + } + entityType := types.EntityTypeVersion + if o.EntityType != nil { + entityType = *o.EntityType + } + + req.Props.EntityType = &entityType + if o.EntityTypeNil { + req.Props.EntityType = nil + } + + req.Props.UploadSessionID = uuid.Must(uuid.NewV4()).String() + + if m.stateless { + return m.updateStateless(ctx, req, o) + } + + // Prepare for upload + uploadSession, err := m.fs.PrepareUpload(ctx, req) + if err != nil { + return nil, fmt.Errorf("faield to prepare uplaod: %w", err) + } + + if err := m.Upload(ctx, req, uploadSession.Policy); err != nil { + m.OnUploadFailed(ctx, uploadSession) + return nil, fmt.Errorf("failed to upload new entity: %w", err) + } + + file, err := m.CompleteUpload(ctx, uploadSession) + if err != nil { + m.OnUploadFailed(ctx, uploadSession) + return nil, fmt.Errorf("failed to complete update: %w", err) + } + + return file, nil +} + +func (m *manager) OnUploadFailed(ctx context.Context, session *fs.UploadSession) { + ctx = context.WithoutCancel(ctx) + if !m.stateless { + if session.LockToken != "" { + if err := m.Unlock(ctx, session.LockToken); err != nil { + m.l.Warning("OnUploadFailed hook failed to unlock: %s", err) + } + } + + if session.NewFileCreated { + if err := m.Delete(ctx, []*fs.URI{session.Props.Uri}, fs.WithSysSkipSoftDelete(true)); err != nil { + m.l.Warning("OnUploadFailed hook failed to delete file: %s", err) + } + } else { + if err := m.fs.VersionControl(ctx, session.Props.Uri, session.EntityID, true); err != nil { + m.l.Warning("OnUploadFailed hook failed to version control: %s", err) + } + } + } else { + d, err := m.GetStorageDriver(ctx, m.CastStoragePolicyOnSlave(ctx, session.Policy)) + if err != nil { + m.l.Warning("OnUploadFailed hook failed: %s", err) + } + + if failed, err := d.Delete(ctx, session.Props.SavePath); err != nil { + m.l.Warning("OnUploadFailed hook failed to remove uploaded file: %s, failed file: %v", err, failed) + } + } +} + +// similar to Update, but expected to be executed on slave node. 
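+// The slave only writes the file content with its local storage driver; PrepareUpload,
+// CompleteUpload and OnUploadFailed are all delegated to the master node through the
+// node client carried in the options.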
+func (m *manager) updateStateless(ctx context.Context, req *fs.UploadRequest, o *fs.FsOption) (fs.File, error) { + // Prepare for upload + res, err := o.Node.PrepareUpload(ctx, &fs.StatelessPrepareUploadService{ + UploadRequest: req, + UserID: o.StatelessUserID, + }) + if err != nil { + return nil, fmt.Errorf("faield to prepare uplaod: %w", err) + } + + req.Props = res.Req.Props + if err := m.Upload(ctx, req, res.Session.Policy); err != nil { + if err := o.Node.OnUploadFailed(ctx, &fs.StatelessOnUploadFailedService{ + UploadSession: res.Session, + UserID: o.StatelessUserID, + }); err != nil { + m.l.Warning("Failed to call stateless OnUploadFailed: %s", err) + } + return nil, fmt.Errorf("failed to upload new entity: %w", err) + } + + err = o.Node.CompleteUpload(ctx, &fs.StatelessCompleteUploadService{ + UploadSession: res.Session, + UserID: o.StatelessUserID, + }) + if err != nil { + if err := o.Node.OnUploadFailed(ctx, &fs.StatelessOnUploadFailedService{ + UploadSession: res.Session, + UserID: o.StatelessUserID, + }); err != nil { + m.l.Warning("Failed to call stateless OnUploadFailed: %s", err) + } + return nil, fmt.Errorf("failed to complete update: %w", err) + } + + return nil, nil +} + +func (m *manager) onNewEntityUploaded(ctx context.Context, session *fs.UploadSession, d driver.Handler) { + if !m.stateless { + // Submit media meta task for new entity + m.mediaMetaForNewEntity(ctx, session, d) + } +} + +// Upload sentinel check task is used for compliant storage policy (COS, S3...), it will delete the marked entity. +// It is expected to be queued after upload session is created, and canceled after upload callback is completed. +// If this task is executed, it means the upload callback does not complete in time. +type ( + UploadSentinelCheckTask struct { + *queue.DBTask + } + UploadSentinelCheckTaskState struct { + Session *fs.UploadSession `json:"session"` + } +) + +const ( + uploadSentinelCheckMargin = 5 * time.Minute +) + +func init() { + queue.RegisterResumableTaskFactory(queue.UploadSentinelCheckTaskType, NewUploadSentinelCheckTaskFromModel) +} + +func NewUploadSentinelCheckTaskFromModel(task *ent.Task) queue.Task { + return &UploadSentinelCheckTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func newUploadSentinelCheckTask(ctx context.Context, uploadSession *fs.UploadSession) (*ExplicitEntityRecycleTask, error) { + state := &UploadSentinelCheckTaskState{ + Session: uploadSession, + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + resumeAfter := uploadSession.Props.ExpireAt.Add(uploadSentinelCheckMargin) + t := &ExplicitEntityRecycleTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.UploadSentinelCheckTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{ + ResumeTime: resumeAfter.Unix(), + }, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func (m *UploadSentinelCheckTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + taskClient := dep.TaskClient() + l := dep.Logger() + fm := NewFileManager(dep, inventory.UserFromContext(ctx)).(*manager) + + // Check if sentinel is canceled due to callback complete + t, err := taskClient.GetTaskByID(ctx, m.ID()) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get task by ID: %w", err) + } + + if t.Status == task.StatusCompleted { + l.Info("Upload sentinel check task [%d] is 
canceled due to callback complete.", m.ID()) + return task.StatusCompleted, nil + } + + // unmarshal state + state := &UploadSentinelCheckTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + + l.Info("Upload sentinel check triggered, clean up stale place holder entity [%d].", state.Session.EntityID) + entity, err := fm.fs.GetEntity(ctx, state.Session.EntityID) + if err != nil { + l.Debug("Failed to get entity [%d]: %s, skip sentinel check.", state.Session.EntityID, err) + return task.StatusCompleted, nil + } + + _, d, err := fm.getEntityPolicyDriver(ctx, entity, nil) + if err != nil { + l.Debug("Failed to get storage driver for entity [%d]: %s", state.Session.EntityID, err) + return task.StatusError, err + } + + _, err = d.Delete(ctx, entity.Source()) + if err != nil { + l.Debug("Failed to delete entity source [%d]: %s", state.Session.EntityID, err) + return task.StatusError, err + } + + if err := d.CancelToken(ctx, state.Session); err != nil { + l.Debug("Failed to cancel token [%d]: %s", state.Session.EntityID, err) + } + + return task.StatusCompleted, nil +} diff --git a/pkg/filemanager/manager/viewer.go b/pkg/filemanager/manager/viewer.go new file mode 100644 index 00000000..0b85f8c4 --- /dev/null +++ b/pkg/filemanager/manager/viewer.go @@ -0,0 +1,93 @@ +package manager + +import ( + "context" + "encoding/gob" + "fmt" + "time" + + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" +) + +type ( + ViewerSession struct { + ID string `json:"id"` + AccessToken string `json:"access_token"` + Expires int64 `json:"expires"` + File fs.File `json:"-"` + } + ViewerSessionCache struct { + ID string + Uri string + UserID int + FileID int + ViewerID string + Version string + Token string + } + ViewerSessionCacheCtx struct{} + ViewerCtx struct{} +) + +const ( + ViewerSessionCachePrefix = "viewer_session_" + + sessionExpiresPadding = 10 +) + +func init() { + gob.Register(ViewerSessionCache{}) +} + +func (m *manager) CreateViewerSession(ctx context.Context, uri *fs.URI, version string, viewer *setting.Viewer) (*ViewerSession, error) { + file, err := m.fs.Get(ctx, uri, dbfs.WithFileEntities()) + if err != nil { + return nil, err + } + + versionType := types.EntityTypeVersion + found, desired := fs.FindDesiredEntity(file, version, m.hasher, &versionType) + if !found { + return nil, fs.ErrEntityNotExist + } + + if desired.Size() > m.settings.MaxOnlineEditSize(ctx) { + return nil, fs.ErrFileSizeTooBig + } + + sessionID := uuid.Must(uuid.NewV4()).String() + token := util.RandStringRunes(128) + sessionCache := &ViewerSessionCache{ + ID: sessionID, + Uri: file.Uri(false).String(), + UserID: m.user.ID, + ViewerID: viewer.ID, + FileID: file.ID(), + Version: version, + Token: fmt.Sprintf("%s.%s", sessionID, token), + } + ttl := m.settings.ViewerSessionTTL(ctx) + if err := m.kv.Set(ViewerSessionCachePrefix+sessionID, *sessionCache, ttl); err != nil { + return nil, err + } + + return &ViewerSession{ + File: file, + ID: sessionID, + AccessToken: sessionCache.Token, + Expires: time.Now().Add(time.Duration(ttl-sessionExpiresPadding) * time.Second).UnixMilli(), + }, nil +} + +func ViewerSessionFromContext(ctx context.Context) *ViewerSessionCache { + return 
ctx.Value(ViewerSessionCacheCtx{}).(*ViewerSessionCache) +} + +func ViewerFromContext(ctx context.Context) *setting.Viewer { + return ctx.Value(ViewerCtx{}).(*setting.Viewer) +} diff --git a/pkg/filemanager/workflows/archive.go b/pkg/filemanager/workflows/archive.go new file mode 100644 index 00000000..bcef9bb2 --- /dev/null +++ b/pkg/filemanager/workflows/archive.go @@ -0,0 +1,682 @@ +package workflows + +import ( + "archive/zip" + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" +) + +type ( + CreateArchiveTask struct { + *queue.DBTask + + l logging.Logger + state *CreateArchiveTaskState + progress queue.Progresses + node cluster.Node + } + + CreateArchiveTaskPhase string + + CreateArchiveTaskState struct { + Uris []string `json:"uris,omitempty"` + Dst string `json:"dst,omitempty"` + TempPath string `json:"temp_path,omitempty"` + ArchiveFile string `json:"archive_file,omitempty"` + Phase CreateArchiveTaskPhase `json:"phase,omitempty"` + SlaveUploadTaskID int `json:"slave__upload_task_id,omitempty"` + SlaveArchiveTaskID int `json:"slave__archive_task_id,omitempty"` + SlaveCompressState *SlaveCreateArchiveTaskState `json:"slave_compress_state,omitempty"` + Failed int `json:"failed,omitempty"` + NodeState `json:",inline"` + } +) + +const ( + CreateArchiveTaskPhaseNotStarted CreateArchiveTaskPhase = "not_started" + CreateArchiveTaskPhaseCompressFiles CreateArchiveTaskPhase = "compress_files" + CreateArchiveTaskPhaseUploadArchive CreateArchiveTaskPhase = "upload_archive" + + CreateArchiveTaskPhaseAwaitSlaveCompressing CreateArchiveTaskPhase = "await_slave_compressing" + CreateArchiveTaskPhaseCreateAndAwaitSlaveUploading CreateArchiveTaskPhase = "await_slave_uploading" + CreateArchiveTaskPhaseCompleteUpload CreateArchiveTaskPhase = "complete_upload" + + ProgressTypeArchiveCount = "archive_count" + ProgressTypeArchiveSize = "archive_size" + ProgressTypeUpload = "upload" + ProgressTypeUploadCount = "upload_count" +) + +func init() { + queue.RegisterResumableTaskFactory(queue.CreateArchiveTaskType, NewCreateArchiveTaskFromModel) +} + +// NewCreateArchiveTask creates a new CreateArchiveTask +func NewCreateArchiveTask(ctx context.Context, src []string, dst string) (queue.Task, error) { + state := &CreateArchiveTaskState{ + Uris: src, + Dst: dst, + NodeState: NodeState{}, + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + t := &CreateArchiveTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.CreateArchiveTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{}, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func 
NewCreateArchiveTaskFromModel(task *ent.Task) queue.Task { + return &CreateArchiveTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func (m *CreateArchiveTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + m.l = dep.Logger() + + m.Lock() + if m.progress == nil { + m.progress = make(queue.Progresses) + } + m.Unlock() + + // unmarshal state + state := &CreateArchiveTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + m.state = state + + // select node + node, err := allocateNode(ctx, dep, &m.state.NodeState, types.NodeCapabilityCreateArchive) + if err != nil { + return task.StatusError, fmt.Errorf("failed to allocate node: %w", err) + } + m.node = node + + next := task.StatusCompleted + + if m.node.IsMaster() { + // Initialize temp folder + // Compress files + // Upload files to dst + switch m.state.Phase { + case CreateArchiveTaskPhaseNotStarted, "": + next, err = m.initializeTempFolder(ctx, dep) + case CreateArchiveTaskPhaseCompressFiles: + next, err = m.createArchiveFile(ctx, dep) + case CreateArchiveTaskPhaseUploadArchive: + next, err = m.uploadArchive(ctx, dep) + default: + next, err = task.StatusError, fmt.Errorf("unknown phase %q: %w", m.state.Phase, queue.CriticalErr) + } + } else { + // Listing all files and send to slave node for compressing + // Await compressing and send to slave for uploading + // Await uploading and complete upload + switch m.state.Phase { + case CreateArchiveTaskPhaseNotStarted, "": + next, err = m.listEntitiesAndSendToSlave(ctx, dep) + case CreateArchiveTaskPhaseAwaitSlaveCompressing: + next, err = m.awaitSlaveCompressing(ctx, dep) + case CreateArchiveTaskPhaseCreateAndAwaitSlaveUploading: + next, err = m.createAndAwaitSlaveUploading(ctx, dep) + case CreateArchiveTaskPhaseCompleteUpload: + next, err = m.completeUpload(ctx, dep) + default: + next, err = task.StatusError, fmt.Errorf("unknown phase %q: %w", m.state.Phase, queue.CriticalErr) + } + } + + newStateStr, marshalErr := json.Marshal(m.state) + if marshalErr != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr) + } + + m.Lock() + m.Task.PrivateState = string(newStateStr) + m.Unlock() + return next, err +} + +func (m *CreateArchiveTask) Cleanup(ctx context.Context) error { + if m.state.SlaveCompressState != nil && m.state.SlaveCompressState.TempPath != "" && m.node != nil { + if err := m.node.CleanupFolders(context.Background(), m.state.SlaveCompressState.TempPath); err != nil { + m.l.Warning("Failed to cleanup slave temp folder %s: %s", m.state.SlaveCompressState.TempPath, err) + } + } + + if m.state.TempPath != "" { + time.Sleep(time.Duration(1) * time.Second) + return os.RemoveAll(m.state.TempPath) + } + + return nil +} + +func (m *CreateArchiveTask) initializeTempFolder(ctx context.Context, dep dependency.Dep) (task.Status, error) { + tempPath, err := prepareTempFolder(ctx, dep, m) + if err != nil { + return task.StatusError, fmt.Errorf("failed to prepare temp folder: %w", err) + } + + m.state.TempPath = tempPath + m.state.Phase = CreateArchiveTaskPhaseCompressFiles + m.ResumeAfter(0) + return task.StatusSuspending, nil +} + +func (m *CreateArchiveTask) listEntitiesAndSendToSlave(ctx context.Context, dep dependency.Dep) (task.Status, error) { + uris, err := fs.NewUriFromStrings(m.state.Uris...) 
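+ // Dry-run listing: CreateArchive below writes into io.Discard and only fires
+ // the WithDryRun callback, which collects every entity (plus its storage
+ // policy) that would go into the archive. The listing is then shipped to the
+ // slave node, which performs the actual compression in SlaveCreateArchiveTask.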
+ if err != nil { + return task.StatusError, fmt.Errorf("failed to create uri from strings: %s (%w)", err, queue.CriticalErr) + } + + payload := &SlaveCreateArchiveTaskState{ + Entities: make([]SlaveCreateArchiveEntity, 0, len(uris)), + Policies: make(map[int]*ent.StoragePolicy), + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + storagePolicyClient := dep.StoragePolicyClient() + + failed, err := fm.CreateArchive(ctx, uris, io.Discard, + fs.WithDryRun(func(name string, e fs.Entity) { + payload.Entities = append(payload.Entities, SlaveCreateArchiveEntity{ + Entity: e.Model(), + Path: name, + }) + if _, ok := payload.Policies[e.PolicyID()]; !ok { + policy, err := storagePolicyClient.GetPolicyByID(ctx, e.PolicyID()) + if err != nil { + m.l.Warning("Failed to get policy %d: %s", e.PolicyID(), err) + } else { + payload.Policies[e.PolicyID()] = policy + } + } + }), + fs.WithMaxArchiveSize(user.Edges.Group.Settings.CompressSize), + ) + if err != nil { + return task.StatusError, fmt.Errorf("failed to compress files: %w", err) + } + + m.state.Failed = failed + payloadStr, err := json.Marshal(payload) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal payload: %w", err) + } + + taskId, err := m.node.CreateTask(ctx, queue.SlaveCreateArchiveTaskType, string(payloadStr)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create slave task: %w", err) + } + + m.state.Phase = CreateArchiveTaskPhaseAwaitSlaveCompressing + m.state.SlaveArchiveTaskID = taskId + m.ResumeAfter((10 * time.Second)) + return task.StatusSuspending, nil +} + +func (m *CreateArchiveTask) awaitSlaveCompressing(ctx context.Context, dep dependency.Dep) (task.Status, error) { + t, err := m.node.GetTask(ctx, m.state.SlaveArchiveTaskID, false) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get slave task: %w", err) + } + + m.Lock() + m.state.NodeState.progress = t.Progress + m.Unlock() + + m.state.SlaveCompressState = &SlaveCreateArchiveTaskState{} + if err := json.Unmarshal([]byte(t.PrivateState), m.state.SlaveCompressState); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal slave compress state: %s (%w)", err, queue.CriticalErr) + } + + if t.Status == task.StatusError { + return task.StatusError, fmt.Errorf("slave task failed: %s (%w)", t.Error, queue.CriticalErr) + } + + if t.Status == task.StatusCanceled { + return task.StatusError, fmt.Errorf("slave task canceled (%w)", queue.CriticalErr) + } + + if t.Status == task.StatusCompleted { + m.state.Phase = CreateArchiveTaskPhaseCreateAndAwaitSlaveUploading + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + m.l.Info("Slave task %d is still compressing, resume after 30s.", m.state.SlaveArchiveTaskID) + m.ResumeAfter((time.Second * 30)) + return task.StatusSuspending, nil +} + +func (m *CreateArchiveTask) createAndAwaitSlaveUploading(ctx context.Context, dep dependency.Dep) (task.Status, error) { + u := inventory.UserFromContext(ctx) + + if m.state.SlaveUploadTaskID == 0 { + dst, err := fs.NewUriFromString(m.state.Dst) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse dst uri %q: %s (%w)", m.state.Dst, err, queue.CriticalErr) + } + + // Create slave upload task + payload := &SlaveUploadTaskState{ + Files: []SlaveUploadEntity{ + { + Size: m.state.SlaveCompressState.CompressedSize, + Uri: dst, + Src: m.state.SlaveCompressState.ZipFilePath, + }, + }, + MaxParallel: dep.SettingProvider().MaxParallelTransfer(ctx), + UserID: u.ID, + } + + 
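+ // The compressed archive stays on the slave node: its local ZipFilePath is
+ // handed back to the same node as a one-file SlaveUploadTask, so the zip does
+ // not have to pass through the master. An illustrative payload follows (values
+ // are examples only; the exact string form of "uri" depends on how fs.URI
+ // marshals to JSON):
+ //
+ //   {
+ //     "max_parallel": 4,
+ //     "files": [
+ //       {"uri": "<destination uri>", "src": "/path/on/slave/xxx.zip", "size": 1048576, "index": 0}
+ //     ],
+ //     "transferred": null,
+ //     "user_id": 1
+ //   }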
payloadStr, err := json.Marshal(payload) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal payload: %w", err) + } + + taskId, err := m.node.CreateTask(ctx, queue.SlaveUploadTaskType, string(payloadStr)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create slave task: %w", err) + } + + m.state.NodeState.progress = nil + m.state.SlaveUploadTaskID = taskId + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + m.l.Info("Checking slave upload task %d...", m.state.SlaveUploadTaskID) + t, err := m.node.GetTask(ctx, m.state.SlaveUploadTaskID, true) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get slave task: %w", err) + } + + m.Lock() + m.state.NodeState.progress = t.Progress + m.Unlock() + + if t.Status == task.StatusError { + return task.StatusError, fmt.Errorf("slave task failed: %s (%w)", t.Error, queue.CriticalErr) + } + + if t.Status == task.StatusCanceled { + return task.StatusError, fmt.Errorf("slave task canceled (%w)", queue.CriticalErr) + } + + if t.Status == task.StatusCompleted { + m.state.Phase = CreateArchiveTaskPhaseCompleteUpload + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + m.l.Info("Slave task %d is still uploading, resume after 30s.", m.state.SlaveUploadTaskID) + m.ResumeAfter(time.Second * 30) + return task.StatusSuspending, nil +} + +func (m *CreateArchiveTask) completeUpload(ctx context.Context, dep dependency.Dep) (task.Status, error) { + return task.StatusCompleted, nil +} + +func (m *CreateArchiveTask) createArchiveFile(ctx context.Context, dep dependency.Dep) (task.Status, error) { + uris, err := fs.NewUriFromStrings(m.state.Uris...) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create uri from strings: %s (%w)", err, queue.CriticalErr) + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + + // Create temp zip file + fileName := fmt.Sprintf("%s.zip", uuid.Must(uuid.NewV4())) + zipFilePath := filepath.Join( + m.state.TempPath, + fileName, + ) + zipFile, err := util.CreatNestedFile(zipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create zip file: %w", err) + } + + defer zipFile.Close() + + // Start compressing + m.Lock() + m.progress[ProgressTypeArchiveCount] = &queue.Progress{} + m.progress[ProgressTypeArchiveSize] = &queue.Progress{} + m.Unlock() + failed, err := fm.CreateArchive(ctx, uris, zipFile, + fs.WithArchiveCompression(true), + fs.WithMaxArchiveSize(user.Edges.Group.Settings.CompressSize), + fs.WithProgressFunc(func(current, diff int64, total int64) { + atomic.AddInt64(&m.progress[ProgressTypeArchiveSize].Current, diff) + atomic.AddInt64(&m.progress[ProgressTypeArchiveCount].Current, 1) + }), + ) + if err != nil { + zipFile.Close() + _ = os.Remove(zipFilePath) + return task.StatusError, fmt.Errorf("failed to compress files: %w", err) + } + + m.state.Failed = failed + m.Lock() + delete(m.progress, ProgressTypeArchiveSize) + delete(m.progress, ProgressTypeArchiveCount) + m.Unlock() + + m.state.Phase = CreateArchiveTaskPhaseUploadArchive + m.state.ArchiveFile = fileName + m.ResumeAfter(0) + return task.StatusSuspending, nil +} + +func (m *CreateArchiveTask) uploadArchive(ctx context.Context, dep dependency.Dep) (task.Status, error) { + fm := manager.NewFileManager(dep, inventory.UserFromContext(ctx)) + zipFilePath := filepath.Join( + m.state.TempPath, + m.state.ArchiveFile, + ) + + m.l.Info("Uploading archive file %s to %s...", zipFilePath, m.state.Dst) + + uri, err := 
fs.NewUriFromString(m.state.Dst) + if err != nil { + return task.StatusError, fmt.Errorf( + "failed to parse dst uri %q: %s (%w)", + m.state.Dst, + err, + queue.CriticalErr, + ) + } + + file, err := os.Open(zipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to open compressed archive %q: %s", m.state.ArchiveFile, err) + } + defer file.Close() + fi, err := file.Stat() + if err != nil { + return task.StatusError, fmt.Errorf("failed to get file info: %w", err) + } + size := fi.Size() + + m.Lock() + m.progress[ProgressTypeUpload] = &queue.Progress{} + m.Unlock() + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: uri, + Size: size, + }, + ProgressFunc: func(current, diff int64, total int64) { + atomic.StoreInt64(&m.progress[ProgressTypeUpload].Current, current) + atomic.StoreInt64(&m.progress[ProgressTypeUpload].Total, total) + }, + File: file, + Seeker: file, + } + + _, err = fm.Update(ctx, fileData) + if err != nil { + return task.StatusError, fmt.Errorf("failed to upload archive file: %w", err) + } + + return task.StatusCompleted, nil +} + +func (m *CreateArchiveTask) Progress(ctx context.Context) queue.Progresses { + m.Lock() + defer m.Unlock() + + if m.state.NodeState.progress != nil { + merged := make(queue.Progresses) + for k, v := range m.progress { + merged[k] = v + } + + for k, v := range m.state.NodeState.progress { + merged[k] = v + } + + return merged + } + return m.progress +} + +func (m *CreateArchiveTask) Summarize(hasher hashid.Encoder) *queue.Summary { + // unmarshal state + if m.state == nil { + if err := json.Unmarshal([]byte(m.State()), &m.state); err != nil { + return nil + } + } + + failed := m.state.Failed + if m.state.SlaveCompressState != nil { + failed = m.state.SlaveCompressState.Failed + } + + return &queue.Summary{ + NodeID: m.state.NodeID, + Phase: string(m.state.Phase), + Props: map[string]any{ + SummaryKeySrcMultiple: m.state.Uris, + SummaryKeyDst: m.state.Dst, + SummaryKeyFailed: failed, + }, + } +} + +type ( + SlaveCreateArchiveEntity struct { + Entity *ent.Entity `json:"entity"` + Path string `json:"path"` + } + SlaveCreateArchiveTaskState struct { + Entities []SlaveCreateArchiveEntity `json:"entities"` + Policies map[int]*ent.StoragePolicy `json:"policies"` + CompressedSize int64 `json:"compressed_size"` + TempPath string `json:"temp_path"` + ZipFilePath string `json:"zip_file_path"` + Failed int `json:"failed"` + } + SlaveCreateArchiveTask struct { + *queue.InMemoryTask + + mu sync.RWMutex + progress queue.Progresses + l logging.Logger + state *SlaveCreateArchiveTaskState + } +) + +// NewSlaveCreateArchiveTask creates a new SlaveCreateArchiveTask from raw private state +func NewSlaveCreateArchiveTask(ctx context.Context, props *types.SlaveTaskProps, id int, state string) queue.Task { + return &SlaveCreateArchiveTask{ + InMemoryTask: &queue.InMemoryTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + ID: id, + CorrelationID: logging.CorrelationID(ctx), + PublicState: &types.TaskPublicState{ + SlaveTaskProps: props, + }, + PrivateState: state, + }, + }, + }, + + progress: make(queue.Progresses), + } +} + +func (t *SlaveCreateArchiveTask) Do(ctx context.Context) (task.Status, error) { + ctx = prepareSlaveTaskCtx(ctx, t.Model().PublicState.SlaveTaskProps) + dep := dependency.FromContext(ctx) + t.l = dep.Logger() + fm := manager.NewFileManager(dep, nil) + + // unmarshal state + state := &SlaveCreateArchiveTaskState{} + if err := json.Unmarshal([]byte(t.State()), state); err != nil { + return task.StatusError, 
fmt.Errorf("failed to unmarshal state: %w", err) + } + + t.state = state + + totalFiles := int64(0) + totalFileSize := int64(0) + for _, e := range t.state.Entities { + totalFiles++ + totalFileSize += e.Entity.Size + } + + t.Lock() + t.progress[ProgressTypeArchiveCount] = &queue.Progress{Total: totalFiles} + t.progress[ProgressTypeArchiveSize] = &queue.Progress{Total: totalFileSize} + t.Unlock() + + // 3. Create temp workspace + tempPath, err := prepareTempFolder(ctx, dep, t) + if err != nil { + return task.StatusError, fmt.Errorf("failed to prepare temp folder: %w", err) + } + t.state.TempPath = tempPath + + // 2. Create archive file + fileName := fmt.Sprintf("%s.zip", uuid.Must(uuid.NewV4())) + zipFilePath := filepath.Join( + t.state.TempPath, + fileName, + ) + zipFile, err := util.CreatNestedFile(zipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create zip file: %w", err) + } + + defer zipFile.Close() + + zipWriter := zip.NewWriter(zipFile) + defer zipWriter.Close() + + // 3. Download each entity and write into zip file + for _, e := range t.state.Entities { + policy, ok := t.state.Policies[e.Entity.StoragePolicyEntities] + if !ok { + state.Failed++ + t.l.Warning("Policy not found for entity %d, skipping...", e.Entity.ID) + continue + } + + entity := fs.NewEntity(e.Entity) + es, err := fm.GetEntitySource(ctx, 0, + fs.WithEntity(entity), + fs.WithPolicy(fm.CastStoragePolicyOnSlave(ctx, policy)), + ) + if err != nil { + state.Failed++ + t.l.Warning("Failed to get entity source for entity %d: %s, skipping...", e.Entity.ID, err) + continue + } + + // Write to zip file + header := &zip.FileHeader{ + Name: e.Path, + Modified: entity.UpdatedAt(), + UncompressedSize64: uint64(entity.Size()), + Method: zip.Deflate, + } + + writer, err := zipWriter.CreateHeader(header) + if err != nil { + es.Close() + state.Failed++ + t.l.Warning("Failed to create zip header for %s: %s, skipping...", e.Path, err) + continue + } + + es.Apply(entitysource.WithContext(ctx)) + _, err = io.Copy(writer, es) + es.Close() + if err != nil { + state.Failed++ + t.l.Warning("Failed to write entity %d to zip file: %s, skipping...", e.Entity.ID, err) + } + + atomic.AddInt64(&t.progress[ProgressTypeArchiveSize].Current, entity.Size()) + atomic.AddInt64(&t.progress[ProgressTypeArchiveCount].Current, 1) + } + + zipWriter.Close() + stat, err := zipFile.Stat() + if err != nil { + return task.StatusError, fmt.Errorf("failed to get compressed file info: %w", err) + } + + t.state.CompressedSize = stat.Size() + t.state.ZipFilePath = zipFilePath + // Clear unused fields to save space + t.state.Entities = nil + t.state.Policies = nil + + newStateStr, marshalErr := json.Marshal(t.state) + if marshalErr != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr) + } + + t.Lock() + t.Task.PrivateState = string(newStateStr) + t.Unlock() + return task.StatusCompleted, nil +} + +func (m *SlaveCreateArchiveTask) Progress(ctx context.Context) queue.Progresses { + m.Lock() + defer m.Unlock() + + return m.progress +} diff --git a/pkg/filemanager/workflows/extract.go b/pkg/filemanager/workflows/extract.go new file mode 100644 index 00000000..14d907e9 --- /dev/null +++ b/pkg/filemanager/workflows/extract.go @@ -0,0 +1,766 @@ +package workflows + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + 
"github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" + "github.com/mholt/archiver/v4" +) + +type ( + ExtractArchiveTask struct { + *queue.DBTask + + l logging.Logger + state *ExtractArchiveTaskState + progress queue.Progresses + node cluster.Node + } + ExtractArchiveTaskPhase string + ExtractArchiveTaskState struct { + Uri string `json:"uri,omitempty"` + Encoding string `json:"encoding,omitempty"` + Dst string `json:"dst,omitempty"` + TempPath string `json:"temp_path,omitempty"` + TempZipFilePath string `json:"temp_zip_file_path,omitempty"` + ProcessedCursor string `json:"processed_cursor,omitempty"` + SlaveTaskID int `json:"slave_task_id,omitempty"` + NodeState `json:",inline"` + Phase ExtractArchiveTaskPhase `json:"phase,omitempty"` + } +) + +const ( + ExtractArchivePhaseNotStarted ExtractArchiveTaskPhase = "" + ExtractArchivePhaseDownloadZip ExtractArchiveTaskPhase = "download_zip" + ExtractArchivePhaseAwaitSlaveComplete ExtractArchiveTaskPhase = "await_slave_complete" + + ProgressTypeExtractCount = "extract_count" + ProgressTypeExtractSize = "extract_size" + ProgressTypeDownload = "download" + + SummaryKeySrc = "src" + SummaryKeyDst = "dst" +) + +func init() { + queue.RegisterResumableTaskFactory(queue.ExtractArchiveTaskType, NewExtractArchiveTaskFromModel) +} + +// NewExtractArchiveTask creates a new ExtractArchiveTask +func NewExtractArchiveTask(ctx context.Context, src, dst, encoding string) (queue.Task, error) { + state := &ExtractArchiveTaskState{ + Uri: src, + Dst: dst, + Encoding: encoding, + NodeState: NodeState{}, + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + t := &ExtractArchiveTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.ExtractArchiveTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{}, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func NewExtractArchiveTaskFromModel(task *ent.Task) queue.Task { + return &ExtractArchiveTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func (m *ExtractArchiveTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + m.l = dep.Logger() + + m.Lock() + if m.progress == nil { + m.progress = make(queue.Progresses) + } + m.Unlock() + + // unmarshal state + state := &ExtractArchiveTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + m.state = state + + // select node + node, err := allocateNode(ctx, dep, &m.state.NodeState, types.NodeCapabilityExtractArchive) + if err != nil { + return task.StatusError, fmt.Errorf("failed to allocate node: %w", err) + } + m.node = node + + next := task.StatusCompleted + + if node.IsMaster() { + switch m.state.Phase { + case ExtractArchivePhaseNotStarted: + next, err = m.masterExtractArchive(ctx, dep) + case 
ExtractArchivePhaseDownloadZip: + next, err = m.masterDownloadZip(ctx, dep) + default: + next, err = task.StatusError, fmt.Errorf("unknown phase %q: %w", m.state.Phase, queue.CriticalErr) + } + } else { + switch m.state.Phase { + case ExtractArchivePhaseNotStarted: + next, err = m.createSlaveExtractTask(ctx, dep) + case ExtractArchivePhaseAwaitSlaveComplete: + next, err = m.awaitSlaveExtractComplete(ctx, dep) + default: + next, err = task.StatusError, fmt.Errorf("unknown phase %q: %w", m.state.Phase, queue.CriticalErr) + } + } + + newStateStr, marshalErr := json.Marshal(m.state) + if marshalErr != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr) + } + + m.Lock() + m.Task.PrivateState = string(newStateStr) + m.Unlock() + return next, err +} + +func (m *ExtractArchiveTask) createSlaveExtractTask(ctx context.Context, dep dependency.Dep) (task.Status, error) { + uri, err := fs.NewUriFromString(m.state.Uri) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse src uri: %s (%w)", err, queue.CriticalErr) + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + + // Get entity source to extract + archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr) + } + + // Validate file size + if user.Edges.Group.Settings.DecompressSize > 0 && archiveFile.Size() > user.Edges.Group.Settings.DecompressSize { + return task.StatusError, + fmt.Errorf("file size %d exceeds the limit %d (%w)", archiveFile.Size(), user.Edges.Group.Settings.DecompressSize, queue.CriticalErr) + } + + // Create slave task + storagePolicyClient := dep.StoragePolicyClient() + policy, err := storagePolicyClient.GetPolicyByID(ctx, archiveFile.PrimaryEntity().PolicyID()) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get policy: %w", err) + } + + payload := &SlaveExtractArchiveTaskState{ + FileName: archiveFile.DisplayName(), + Entity: archiveFile.PrimaryEntity().Model(), + Policy: policy, + Encoding: m.state.Encoding, + Dst: m.state.Dst, + UserID: user.ID, + } + + payloadStr, err := json.Marshal(payload) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal payload: %w", err) + } + + taskId, err := m.node.CreateTask(ctx, queue.SlaveExtractArchiveType, string(payloadStr)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create slave task: %w", err) + } + + m.state.Phase = ExtractArchivePhaseAwaitSlaveComplete + m.state.SlaveTaskID = taskId + m.ResumeAfter((10 * time.Second)) + return task.StatusSuspending, nil +} + +func (m *ExtractArchiveTask) awaitSlaveExtractComplete(ctx context.Context, dep dependency.Dep) (task.Status, error) { + t, err := m.node.GetTask(ctx, m.state.SlaveTaskID, true) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get slave task: %w", err) + } + + m.Lock() + m.state.NodeState.progress = t.Progress + m.Unlock() + + if t.Status == task.StatusError { + return task.StatusError, fmt.Errorf("slave task failed: %s (%w)", t.Error, queue.CriticalErr) + } + + if t.Status == task.StatusCanceled { + return task.StatusError, fmt.Errorf("slave task canceled (%w)", queue.CriticalErr) + } + + if t.Status == task.StatusCompleted { + return task.StatusCompleted, nil + } + + m.l.Info("Slave task %d is still compressing, resume after 30s.", m.state.SlaveTaskID) + 
m.ResumeAfter((time.Second * 30)) + return task.StatusSuspending, nil +} + +func (m *ExtractArchiveTask) masterExtractArchive(ctx context.Context, dep dependency.Dep) (task.Status, error) { + uri, err := fs.NewUriFromString(m.state.Uri) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse src uri: %s (%w)", err, queue.CriticalErr) + } + + dst, err := fs.NewUriFromString(m.state.Dst) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse dst uri: %s (%w)", err, queue.CriticalErr) + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + + // Get entity source to extract + archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr) + } + + // Validate file size + if user.Edges.Group.Settings.DecompressSize > 0 && archiveFile.Size() > user.Edges.Group.Settings.DecompressSize { + return task.StatusError, + fmt.Errorf("file size %d exceeds the limit %d (%w)", archiveFile.Size(), user.Edges.Group.Settings.DecompressSize, queue.CriticalErr) + } + + es, err := fm.GetEntitySource(ctx, 0, fs.WithEntity(archiveFile.PrimaryEntity())) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get entity source: %w", err) + } + + defer es.Close() + + m.l.Info("Extracting archive %q to %q", uri, m.state.Dst) + // Identify file format + format, readStream, err := archiver.Identify(archiveFile.DisplayName(), es) + if err != nil { + return task.StatusError, fmt.Errorf("failed to identify archive format: %w", err) + } + + m.l.Info("Archive file %q format identified as %q", uri, format.Name()) + + extractor, ok := format.(archiver.Extractor) + if !ok { + return task.StatusError, fmt.Errorf("format not an extractor %s") + } + + if format.Name() == ".zip" { + // Zip extractor requires a Seeker+ReadAt + if m.state.TempZipFilePath == "" && !es.IsLocal() { + m.state.Phase = ExtractArchivePhaseDownloadZip + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + if m.state.TempZipFilePath != "" { + // Use temp zip file path + zipFile, err := os.Open(m.state.TempZipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to open temp zip file: %w", err) + } + + defer zipFile.Close() + readStream = zipFile + } + + if es.IsLocal() { + if _, err = es.Seek(0, 0); err != nil { + return task.StatusError, fmt.Errorf("failed to seek entity source: %w", err) + } + + readStream = es + } + + if m.state.Encoding != "" { + m.l.Info("Using encoding %q for zip archive", m.state.Encoding) + extractor = archiver.Zip{TextEncoding: m.state.Encoding} + } + } + + needSkipToCursor := false + if m.state.ProcessedCursor != "" { + needSkipToCursor = true + } + m.Lock() + m.progress[ProgressTypeExtractCount] = &queue.Progress{} + m.progress[ProgressTypeExtractSize] = &queue.Progress{} + m.Unlock() + + // extract and upload + err = extractor.Extract(ctx, readStream, nil, func(ctx context.Context, f archiver.File) error { + if needSkipToCursor && f.NameInArchive != m.state.ProcessedCursor { + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + m.l.Info("File %q already processed, skipping...", f.NameInArchive) + return nil + } + + // Found cursor, start from cursor +1 + if m.state.ProcessedCursor == f.NameInArchive { + 
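+ // This entry is the last one fully processed before the previous suspension;
+ // count it towards progress and clear the skip flag so extraction resumes
+ // with the entry that follows it in the archive.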
atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + needSkipToCursor = false + return nil + } + + rawPath := util.FormSlash(f.NameInArchive) + savePath := dst.JoinRaw(rawPath) + + // Check if path is legit + if !strings.HasPrefix(savePath.Path(), util.FillSlash(path.Clean(dst.Path()))) { + m.l.Warning("Path %q is not legit, skipping...", f.NameInArchive) + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + return nil + } + + if f.FileInfo.IsDir() { + _, err := fm.Create(ctx, savePath, types.FileTypeFolder) + if err != nil { + m.l.Warning("Failed to create directory %q: %s, skipping...", rawPath, err) + } + + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + m.state.ProcessedCursor = f.NameInArchive + return nil + } + + fileStream, err := f.Open() + if err != nil { + m.l.Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err) + return nil + } + + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: savePath, + Size: f.Size(), + }, + ProgressFunc: func(current, diff int64, total int64) { + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, diff) + }, + File: fileStream, + } + + _, err = fm.Update(ctx, fileData, fs.WithNoEntityType()) + if err != nil { + return fmt.Errorf("failed to upload file %q in archive file: %w", rawPath, err) + } + + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + m.state.ProcessedCursor = f.NameInArchive + return nil + }) + + if err != nil { + return task.StatusError, fmt.Errorf("failed to extract archive: %w", err) + } + + return task.StatusCompleted, nil +} + +func (m *ExtractArchiveTask) masterDownloadZip(ctx context.Context, dep dependency.Dep) (task.Status, error) { + uri, err := fs.NewUriFromString(m.state.Uri) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse src uri: %s (%w)", err, queue.CriticalErr) + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + + // Get entity source to extract + archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr) + } + + es, err := fm.GetEntitySource(ctx, 0, fs.WithEntity(archiveFile.PrimaryEntity())) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get entity source: %w", err) + } + + defer es.Close() + + // For non-local entity, we need to download the whole zip file first + tempPath, err := prepareTempFolder(ctx, dep, m) + if err != nil { + return task.StatusError, fmt.Errorf("failed to prepare temp folder: %w", err) + } + m.state.TempPath = tempPath + + fileName := fmt.Sprintf("%s.zip", uuid.Must(uuid.NewV4())) + zipFilePath := filepath.Join( + m.state.TempPath, + fileName, + ) + + zipFile, err := util.CreatNestedFile(zipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create zip file: %w", err) + } + + m.Lock() + m.progress[ProgressTypeDownload] = &queue.Progress{Total: es.Entity().Size()} + m.Unlock() + + defer zipFile.Close() + if _, err := io.Copy(zipFile, util.NewCallbackReader(es, func(i int64) { + atomic.AddInt64(&m.progress[ProgressTypeDownload].Current, i) + })); err != nil { + zipFile.Close() + if err := os.Remove(zipFilePath); err != nil { + 
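+ // Best-effort cleanup: a failed removal is only logged, since the whole temp
+ // folder recorded in m.state.TempPath is removed again in Cleanup.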
m.l.Warning("Failed to remove temp zip file %q: %s", zipFilePath, err) + } + return task.StatusError, fmt.Errorf("failed to copy zip file to local temp: %w", err) + } + + m.Lock() + delete(m.progress, ProgressTypeDownload) + m.Unlock() + m.state.TempZipFilePath = zipFilePath + m.state.Phase = ExtractArchivePhaseNotStarted + m.ResumeAfter(0) + return task.StatusSuspending, nil +} + +func (m *ExtractArchiveTask) Summarize(hasher hashid.Encoder) *queue.Summary { + if m.state == nil { + if err := json.Unmarshal([]byte(m.State()), &m.state); err != nil { + return nil + } + } + + return &queue.Summary{ + NodeID: m.state.NodeID, + Phase: string(m.state.Phase), + Props: map[string]any{ + SummaryKeySrc: m.state.Uri, + SummaryKeyDst: m.state.Dst, + }, + } +} + +func (m *ExtractArchiveTask) Progress(ctx context.Context) queue.Progresses { + m.Lock() + defer m.Unlock() + + if m.state.NodeState.progress != nil { + merged := make(queue.Progresses) + for k, v := range m.progress { + merged[k] = v + } + + for k, v := range m.state.NodeState.progress { + merged[k] = v + } + + return merged + } + return m.progress +} + +func (m *ExtractArchiveTask) Cleanup(ctx context.Context) error { + if m.state.TempPath != "" { + time.Sleep(time.Duration(1) * time.Second) + return os.RemoveAll(m.state.TempPath) + } + + return nil +} + +type ( + SlaveExtractArchiveTask struct { + *queue.InMemoryTask + + l logging.Logger + state *SlaveExtractArchiveTaskState + progress queue.Progresses + node cluster.Node + } + + SlaveExtractArchiveTaskState struct { + FileName string `json:"file_name"` + Entity *ent.Entity `json:"entity"` + Policy *ent.StoragePolicy `json:"policy"` + Encoding string `json:"encoding,omitempty"` + Dst string `json:"dst,omitempty"` + UserID int `json:"user_id"` + TempPath string `json:"temp_path,omitempty"` + TempZipFilePath string `json:"temp_zip_file_path,omitempty"` + ProcessedCursor string `json:"processed_cursor,omitempty"` + } +) + +// NewSlaveExtractArchiveTask creates a new SlaveExtractArchiveTask from raw private state +func NewSlaveExtractArchiveTask(ctx context.Context, props *types.SlaveTaskProps, id int, state string) queue.Task { + return &SlaveExtractArchiveTask{ + InMemoryTask: &queue.InMemoryTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + ID: id, + CorrelationID: logging.CorrelationID(ctx), + PublicState: &types.TaskPublicState{ + SlaveTaskProps: props, + }, + PrivateState: state, + }, + }, + }, + + progress: make(queue.Progresses), + } +} + +func (m *SlaveExtractArchiveTask) Do(ctx context.Context) (task.Status, error) { + ctx = prepareSlaveTaskCtx(ctx, m.Model().PublicState.SlaveTaskProps) + dep := dependency.FromContext(ctx) + m.l = dep.Logger() + np, err := dep.NodePool(ctx) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get node pool: %w", err) + } + + m.node, err = np.Get(ctx, types.NodeCapabilityNone, 0) + if err != nil || !m.node.IsMaster() { + return task.StatusError, fmt.Errorf("failed to get master node: %w", err) + } + + fm := manager.NewFileManager(dep, nil) + + // unmarshal state + state := &SlaveExtractArchiveTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + + m.state = state + m.Lock() + if m.progress == nil { + m.progress = make(queue.Progresses) + } + m.progress[ProgressTypeExtractCount] = &queue.Progress{} + m.progress[ProgressTypeExtractSize] = &queue.Progress{} + m.Unlock() + + dst, err := fs.NewUriFromString(m.state.Dst) + if err != nil 
{ + return task.StatusError, fmt.Errorf("failed to parse dst uri: %s (%w)", err, queue.CriticalErr) + } + + // 1. Get entity source + entity := fs.NewEntity(m.state.Entity) + es, err := fm.GetEntitySource(ctx, 0, fs.WithEntity(entity), fs.WithPolicy(fm.CastStoragePolicyOnSlave(ctx, m.state.Policy))) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get entity source: %w", err) + } + + defer es.Close() + + // 2. Identify file format + format, readStream, err := archiver.Identify(m.state.FileName, es) + if err != nil { + return task.StatusError, fmt.Errorf("failed to identify archive format: %w", err) + } + m.l.Info("Archive file %q format identified as %q", m.state.FileName, format.Name()) + + extractor, ok := format.(archiver.Extractor) + if !ok { + return task.StatusError, fmt.Errorf("format not an extractor %s") + } + + if format.Name() == ".zip" { + if _, err = es.Seek(0, 0); err != nil { + return task.StatusError, fmt.Errorf("failed to seek entity source: %w", err) + } + + if m.state.TempZipFilePath == "" && !es.IsLocal() { + tempPath, err := prepareTempFolder(ctx, dep, m) + if err != nil { + return task.StatusError, fmt.Errorf("failed to prepare temp folder: %w", err) + } + m.state.TempPath = tempPath + + fileName := fmt.Sprintf("%s.zip", uuid.Must(uuid.NewV4())) + zipFilePath := filepath.Join( + m.state.TempPath, + fileName, + ) + zipFile, err := util.CreatNestedFile(zipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create zip file: %w", err) + } + + m.Lock() + m.progress[ProgressTypeDownload] = &queue.Progress{Total: es.Entity().Size()} + m.Unlock() + + defer zipFile.Close() + if _, err := io.Copy(zipFile, util.NewCallbackReader(es, func(i int64) { + atomic.AddInt64(&m.progress[ProgressTypeDownload].Current, i) + })); err != nil { + return task.StatusError, fmt.Errorf("failed to copy zip file to local temp: %w", err) + } + + zipFile.Close() + m.state.TempZipFilePath = zipFilePath + } + + if es.IsLocal() { + readStream = es + } else if m.state.TempZipFilePath != "" { + // Use temp zip file path + zipFile, err := os.Open(m.state.TempZipFilePath) + if err != nil { + return task.StatusError, fmt.Errorf("failed to open temp zip file: %w", err) + } + + defer zipFile.Close() + readStream = zipFile + } + + if es.IsLocal() { + readStream = es + } + + if m.state.Encoding != "" { + m.l.Info("Using encoding %q for zip archive", m.state.Encoding) + extractor = archiver.Zip{TextEncoding: m.state.Encoding} + } + } + + needSkipToCursor := false + if m.state.ProcessedCursor != "" { + needSkipToCursor = true + } + + // 3. 
Extract and upload + err = extractor.Extract(ctx, readStream, nil, func(ctx context.Context, f archiver.File) error { + if needSkipToCursor && f.NameInArchive != m.state.ProcessedCursor { + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + m.l.Info("File %q already processed, skipping...", f.NameInArchive) + return nil + } + + // Found cursor, start from cursor +1 + if m.state.ProcessedCursor == f.NameInArchive { + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + needSkipToCursor = false + return nil + } + + rawPath := util.FormSlash(f.NameInArchive) + savePath := dst.JoinRaw(rawPath) + + // Check if path is legit + if !strings.HasPrefix(savePath.Path(), util.FillSlash(path.Clean(dst.Path()))) { + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, f.Size()) + m.l.Warning("Path %q is not legit, skipping...", f.NameInArchive) + return nil + } + + if f.FileInfo.IsDir() { + _, err := fm.Create(ctx, savePath, types.FileTypeFolder, fs.WithNode(m.node), fs.WithStatelessUserID(m.state.UserID)) + if err != nil { + m.l.Warning("Failed to create directory %q: %s, skipping...", rawPath, err) + } + + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + m.state.ProcessedCursor = f.NameInArchive + return nil + } + + fileStream, err := f.Open() + if err != nil { + m.l.Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err) + return nil + } + + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: savePath, + Size: f.Size(), + }, + ProgressFunc: func(current, diff int64, total int64) { + atomic.AddInt64(&m.progress[ProgressTypeExtractSize].Current, diff) + }, + File: fileStream, + } + + _, err = fm.Update(ctx, fileData, fs.WithNode(m.node), fs.WithStatelessUserID(m.state.UserID), fs.WithNoEntityType()) + if err != nil { + return fmt.Errorf("failed to upload file %q in archive file: %w", rawPath, err) + } + + atomic.AddInt64(&m.progress[ProgressTypeExtractCount].Current, 1) + m.state.ProcessedCursor = f.NameInArchive + return nil + }) + + if err != nil { + return task.StatusError, fmt.Errorf("failed to extract archive: %w", err) + } + + return task.StatusCompleted, nil +} + +func (m *SlaveExtractArchiveTask) Cleanup(ctx context.Context) error { + if m.state.TempPath != "" { + time.Sleep(time.Duration(1) * time.Second) + return os.RemoveAll(m.state.TempPath) + } + + return nil +} + +func (m *SlaveExtractArchiveTask) Progress(ctx context.Context) queue.Progresses { + m.Lock() + defer m.Unlock() + return m.progress +} diff --git a/pkg/filemanager/workflows/remote_download.go b/pkg/filemanager/workflows/remote_download.go new file mode 100644 index 00000000..dc7e8fe8 --- /dev/null +++ b/pkg/filemanager/workflows/remote_download.go @@ -0,0 +1,657 @@ +package workflows + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + 
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +type ( + RemoteDownloadTask struct { + *queue.DBTask + + l logging.Logger + state *RemoteDownloadTaskState + node cluster.Node + d downloader.Downloader + progress queue.Progresses + } + RemoteDownloadTaskPhase string + RemoteDownloadTaskState struct { + SrcFileUri string `json:"src_file_uri,omitempty"` + SrcUri string `json:"src_uri,omitempty"` + Dst string `json:"dst,omitempty"` + Handle *downloader.TaskHandle `json:"handle,omitempty"` + Status *downloader.TaskStatus `json:"status,omitempty"` + NodeState `json:",inline"` + Phase RemoteDownloadTaskPhase `json:"phase,omitempty"` + SlaveUploadTaskID int `json:"slave__upload_task_id,omitempty"` + SlaveUploadState *SlaveUploadTaskState `json:"slave_upload_state,omitempty"` + GetTaskStatusTried int `json:"get_task_status_tried,omitempty"` + Transferred map[int]interface{} `json:"transferred,omitempty"` + Failed int `json:"failed,omitempty"` + } +) + +const ( + RemoteDownloadTaskPhaseNotStarted RemoteDownloadTaskPhase = "" + RemoteDownloadTaskPhaseMonitor = "monitor" + RemoteDownloadTaskPhaseTransfer = "transfer" + RemoteDownloadTaskPhaseAwaitSeeding = "seeding" + + GetTaskStatusMaxTries = 5 + + SummaryKeyDownloadStatus = "download" + SummaryKeySrcStr = "src_str" + + ProgressTypeRelocateTransferCount = "relocate" + ProgressTypeUploadSinglePrefix = "upload_single_" + + SummaryKeySrcMultiple = "src_multiple" + SummaryKeySrcDstPolicyID = "dst_policy_id" + SummaryKeyFailed = "failed" +) + +func init() { + queue.RegisterResumableTaskFactory(queue.RemoteDownloadTaskType, NewRemoteDownloadTaskFromModel) +} + +// NewRemoteDownloadTask creates a new RemoteDownloadTask +func NewRemoteDownloadTask(ctx context.Context, src string, srcFile, dst string) (queue.Task, error) { + state := &RemoteDownloadTaskState{ + SrcUri: src, + SrcFileUri: srcFile, + Dst: dst, + NodeState: NodeState{}, + } + stateBytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("failed to marshal state: %w", err) + } + + t := &RemoteDownloadTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + Type: queue.RemoteDownloadTaskType, + CorrelationID: logging.CorrelationID(ctx), + PrivateState: string(stateBytes), + PublicState: &types.TaskPublicState{}, + }, + DirectOwner: inventory.UserFromContext(ctx), + }, + } + return t, nil +} + +func NewRemoteDownloadTaskFromModel(task *ent.Task) queue.Task { + return &RemoteDownloadTask{ + DBTask: &queue.DBTask{ + Task: task, + }, + } +} + +func (m *RemoteDownloadTask) Do(ctx context.Context) (task.Status, error) { + dep := dependency.FromContext(ctx) + m.l = dep.Logger() + + // unmarshal state + state := &RemoteDownloadTaskState{} + if err := json.Unmarshal([]byte(m.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + m.state = state + + // select node + node, err := allocateNode(ctx, dep, &m.state.NodeState, types.NodeCapabilityRemoteDownload) + if err != nil { + return task.StatusError, fmt.Errorf("failed to allocate node: %w", err) + } + m.node = node + + // create downloader instance + if m.d == nil { + d, err := node.CreateDownloader(ctx, dep.RequestClient(), dep.SettingProvider()) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create 
downloader: %w", err) + } + + m.d = d + } + + next := task.StatusCompleted + switch m.state.Phase { + case RemoteDownloadTaskPhaseNotStarted: + next, err = m.createDownloadTask(ctx, dep) + case RemoteDownloadTaskPhaseMonitor, RemoteDownloadTaskPhaseAwaitSeeding: + next, err = m.monitor(ctx, dep) + case RemoteDownloadTaskPhaseTransfer: + if m.node.IsMaster() { + next, err = m.masterTransfer(ctx, dep) + } else { + next, err = m.slaveTransfer(ctx, dep) + } + } + + newStateStr, marshalErr := json.Marshal(m.state) + if marshalErr != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr) + } + + m.Lock() + m.Task.PrivateState = string(newStateStr) + m.Unlock() + return next, err +} + +func (m *RemoteDownloadTask) createDownloadTask(ctx context.Context, dep dependency.Dep) (task.Status, error) { + if m.state.Handle != nil { + m.state.Phase = RemoteDownloadTaskPhaseMonitor + return task.StatusSuspending, nil + } + + user := inventory.UserFromContext(ctx) + torrentUrl := m.state.SrcUri + if m.state.SrcFileUri != "" { + // Target is a torrent file + uri, err := fs.NewUriFromString(m.state.SrcFileUri) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse src file uri: %s (%w)", err, queue.CriticalErr) + } + + fm := manager.NewFileManager(dep, user) + expire := time.Now().Add(dep.SettingProvider().EntityUrlValidDuration(ctx)) + torrentUrls, _, err := fm.GetEntityUrls(ctx, []manager.GetEntityUrlArgs{ + {URI: uri}, + }, fs.WithUrlExpire(&expire)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get torrent entity urls: %w", err) + } + + if len(torrentUrls) == 0 { + return task.StatusError, fmt.Errorf("no torrent urls found") + } + + torrentUrl = torrentUrls[0] + } + + // Create download task + handle, err := m.d.CreateTask(ctx, torrentUrl, user.Edges.Group.Settings.RemoteDownloadOptions) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create download task: %w", err) + } + + m.state.Handle = handle + m.state.Phase = RemoteDownloadTaskPhaseMonitor + return task.StatusSuspending, nil +} + +func (m *RemoteDownloadTask) monitor(ctx context.Context, dep dependency.Dep) (task.Status, error) { + resumeAfter := time.Duration(m.node.Settings(ctx).Interval) * time.Second + + // Update task status + status, err := m.d.Info(ctx, m.state.Handle) + if err != nil { + if errors.Is(err, downloader.ErrTaskNotFount) && m.state.Status != nil { + // If task is not found, but it previously existed, consider it as canceled + m.l.Warning("task not found, consider it as canceled") + return task.StatusCanceled, nil + } + + m.state.GetTaskStatusTried++ + if m.state.GetTaskStatusTried >= GetTaskStatusMaxTries { + return task.StatusError, fmt.Errorf("failed to get task status after %d retry: %w", m.state.GetTaskStatusTried, err) + } + + m.l.Warning("failed to get task info: %s, will retry.", err) + m.ResumeAfter(resumeAfter) + return task.StatusSuspending, nil + } + + // Follow to new handle if needed + if status.FollowedBy != nil { + m.l.Info("Task handle updated to %v", status.FollowedBy) + m.state.Handle = status.FollowedBy + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + if m.state.Status == nil || m.state.Status.Total != status.Total { + m.l.Info("download size changed, re-validate files.") + // First time to get status / total size changed, check user capacity + if err := m.validateFiles(ctx, dep, status); err != nil { + m.state.Status = status + return task.StatusError, fmt.Errorf("failed to validate files: %s (%w)", err, 
queue.CriticalErr) + } + } + + m.state.Status = status + m.state.GetTaskStatusTried = 0 + + m.l.Debug("Monitor %q task state: %s", status.Name, status.State) + switch status.State { + case downloader.StatusSeeding: + m.l.Info("Download task seeding") + if m.state.Phase == RemoteDownloadTaskPhaseMonitor { + // Not transferred + m.state.Phase = RemoteDownloadTaskPhaseTransfer + return task.StatusSuspending, nil + } else if !m.node.Settings(ctx).WaitForSeeding { + // Skip seeding + m.l.Info("Download task seeding skipped.") + return task.StatusCompleted, nil + } else { + // Still seeding + m.ResumeAfter(resumeAfter) + return task.StatusSuspending, nil + } + case downloader.StatusCompleted: + m.l.Info("Download task completed") + if m.state.Phase == RemoteDownloadTaskPhaseMonitor { + // Not transferred + m.state.Phase = RemoteDownloadTaskPhaseTransfer + return task.StatusSuspending, nil + } + // Seeding complete + m.l.Info("Download task seeding completed") + return task.StatusCompleted, nil + case downloader.StatusDownloading: + m.ResumeAfter(resumeAfter) + return task.StatusSuspending, nil + case downloader.StatusUnknown, downloader.StatusError: + return task.StatusError, fmt.Errorf("download task failed with state %q (%w)", status.State, queue.CriticalErr) + } + + m.ResumeAfter(resumeAfter) + return task.StatusSuspending, nil +} + +func (m *RemoteDownloadTask) slaveTransfer(ctx context.Context, dep dependency.Dep) (task.Status, error) { + u := inventory.UserFromContext(ctx) + if m.state.Transferred == nil { + m.state.Transferred = make(map[int]interface{}) + } + + if m.state.SlaveUploadTaskID == 0 { + dstUri, err := fs.NewUriFromString(m.state.Dst) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse dst uri %q: %s (%w)", m.state.Dst, err, queue.CriticalErr) + } + + // Create slave upload task + payload := &SlaveUploadTaskState{ + Files: []SlaveUploadEntity{}, + MaxParallel: dep.SettingProvider().MaxParallelTransfer(ctx), + UserID: u.ID, + } + + // Construct files to be transferred + for _, f := range m.state.Status.Files { + if !f.Selected { + continue + } + + // Skip already transferred + if _, ok := m.state.Transferred[f.Index]; ok { + continue + } + + dst := dstUri.JoinRaw(f.Name) + src := filepath.FromSlash(path.Join(m.state.Status.SavePath, f.Name)) + payload.Files = append(payload.Files, SlaveUploadEntity{ + Src: src, + Uri: dst, + Size: f.Size, + Index: f.Index, + }) + } + + payloadStr, err := json.Marshal(payload) + if err != nil { + return task.StatusError, fmt.Errorf("failed to marshal payload: %w", err) + } + + taskId, err := m.node.CreateTask(ctx, queue.SlaveUploadTaskType, string(payloadStr)) + if err != nil { + return task.StatusError, fmt.Errorf("failed to create slave task: %w", err) + } + + m.state.NodeState.progress = nil + m.state.SlaveUploadTaskID = taskId + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + + m.l.Info("Checking slave upload task %d...", m.state.SlaveUploadTaskID) + t, err := m.node.GetTask(ctx, m.state.SlaveUploadTaskID, true) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get slave task: %w", err) + } + + m.Lock() + m.state.NodeState.progress = t.Progress + m.Unlock() + + m.state.SlaveUploadState = &SlaveUploadTaskState{} + if err := json.Unmarshal([]byte(t.PrivateState), m.state.SlaveUploadState); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal slave compress state: %s (%w)", err, queue.CriticalErr) + } + + if t.Status == task.StatusError || t.Status == task.StatusCompleted { + 
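+ // The slave upload task reached a terminal state. If some files were not
+ // transferred, record the indexes that did succeed and reset SlaveUploadTaskID
+ // so the next attempt creates a fresh slave task covering only the remainder;
+ // otherwise advance to the seeding phase.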
if len(m.state.SlaveUploadState.Transferred) < len(m.state.SlaveUploadState.Files) { + // Not all files transferred, retry + slaveTaskId := m.state.SlaveUploadTaskID + m.state.SlaveUploadTaskID = 0 + for i, _ := range m.state.SlaveUploadState.Transferred { + m.state.Transferred[m.state.SlaveUploadState.Files[i].Index] = struct{}{} + } + + m.l.Warning("Slave task %d failed to transfer %d files, retrying...", slaveTaskId, len(m.state.SlaveUploadState.Files)-len(m.state.SlaveUploadState.Transferred)) + return task.StatusError, fmt.Errorf( + "slave task failed to transfer %d files, first 5 errors: %s", + len(m.state.SlaveUploadState.Files)-len(m.state.SlaveUploadState.Transferred), + m.state.SlaveUploadState.First5TransferErrors, + ) + } else { + m.state.Phase = RemoteDownloadTaskPhaseAwaitSeeding + m.ResumeAfter(0) + return task.StatusSuspending, nil + } + } + + if t.Status == task.StatusCanceled { + return task.StatusError, fmt.Errorf("slave task canceled (%w)", queue.CriticalErr) + } + + m.l.Info("Slave task %d is still uploading, resume after 30s.", m.state.SlaveUploadTaskID) + m.ResumeAfter(time.Second * 30) + return task.StatusSuspending, nil +} + +func (m *RemoteDownloadTask) masterTransfer(ctx context.Context, dep dependency.Dep) (task.Status, error) { + if m.state.Transferred == nil { + m.state.Transferred = make(map[int]interface{}) + } + + maxParallel := dep.SettingProvider().MaxParallelTransfer(ctx) + wg := sync.WaitGroup{} + worker := make(chan int, maxParallel) + for i := 0; i < maxParallel; i++ { + worker <- i + } + + // Sum up total count and select files + totalCount := 0 + totalSize := int64(0) + allFiles := make([]downloader.TaskFile, 0, len(m.state.Status.Files)) + for _, f := range m.state.Status.Files { + if f.Selected { + allFiles = append(allFiles, f) + totalSize += f.Size + totalCount++ + } + } + + m.Lock() + m.progress = make(queue.Progresses) + m.progress[ProgressTypeUploadCount] = &queue.Progress{Total: int64(totalCount)} + m.progress[ProgressTypeUpload] = &queue.Progress{Total: totalSize} + m.Unlock() + + dstUri, err := fs.NewUriFromString(m.state.Dst) + if err != nil { + return task.StatusError, fmt.Errorf("failed to parse dst uri: %s (%w)", err, queue.CriticalErr) + } + + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + failed := int64(0) + ae := serializer.NewAggregateError() + + transferFunc := func(workerId int, file downloader.TaskFile) { + defer func() { + atomic.AddInt64(&m.progress[ProgressTypeUploadCount].Current, 1) + worker <- workerId + wg.Done() + }() + + dst := dstUri.JoinRaw(file.Name) + src := filepath.FromSlash(path.Join(m.state.Status.SavePath, file.Name)) + m.l.Info("Uploading file %s to %s...", src, file.Name, dst) + + progressKey := fmt.Sprintf("%s%d", ProgressTypeUploadSinglePrefix, workerId) + m.Lock() + m.progress[progressKey] = &queue.Progress{Identifier: dst.String(), Total: file.Size} + m.Unlock() + + fileStream, err := os.Open(src) + if err != nil { + m.l.Warning("Failed to open file %s: %s", src, err.Error()) + atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size) + atomic.AddInt64(&failed, 1) + ae.Add(file.Name, fmt.Errorf("failed to open file: %w", err)) + return + } + + defer fileStream.Close() + + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: dst, + Size: file.Size, + }, + ProgressFunc: func(current, diff int64, total int64) { + atomic.AddInt64(&m.progress[progressKey].Current, diff) + atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, diff) + }, + File: 
fileStream, + } + + _, err = fm.Update(ctx, fileData, fs.WithNoEntityType()) + if err != nil { + m.l.Warning("Failed to upload file %s: %s", src, err.Error()) + atomic.AddInt64(&failed, 1) + atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size) + ae.Add(file.Name, fmt.Errorf("failed to upload file: %w", err)) + return + } + + m.Lock() + m.state.Transferred[file.Index] = nil + m.Unlock() + } + + // Start upload files + for _, file := range allFiles { + // Check if file is already transferred + if _, ok := m.state.Transferred[file.Index]; ok { + m.l.Info("File %s already transferred, skipping...", file.Name) + atomic.AddInt64(&m.progress[ProgressTypeUpload].Current, file.Size) + atomic.AddInt64(&m.progress[ProgressTypeUploadCount].Current, 1) + continue + } + + select { + case <-ctx.Done(): + return task.StatusError, ctx.Err() + case workerId := <-worker: + wg.Add(1) + + go transferFunc(workerId, file) + } + } + + wg.Wait() + if failed > 0 { + m.state.Failed = int(failed) + m.l.Error("Failed to transfer %d file(s).", failed) + return task.StatusError, fmt.Errorf("failed to transfer %d file(s), first 5 errors: %s", failed, ae.FormatFirstN(5)) + } + + m.l.Info("All files transferred.") + m.state.Phase = RemoteDownloadTaskPhaseAwaitSeeding + return task.StatusSuspending, nil +} + +func (m *RemoteDownloadTask) awaitSeeding(ctx context.Context, dep dependency.Dep) (task.Status, error) { + return task.StatusSuspending, nil +} + +func (m *RemoteDownloadTask) validateFiles(ctx context.Context, dep dependency.Dep, status *downloader.TaskStatus) error { + // Validate files + user := inventory.UserFromContext(ctx) + fm := manager.NewFileManager(dep, user) + + dstUri, err := fs.NewUriFromString(m.state.Dst) + if err != nil { + return fmt.Errorf("failed to parse dst uri: %w", err) + } + + selectedFiles := lo.Filter(status.Files, func(f downloader.TaskFile, _ int) bool { + return f.Selected + }) + if len(selectedFiles) == 0 { + return fmt.Errorf("no selected file found in download task") + } + + // find the first valid file + var placeholderFileName string + for _, f := range selectedFiles { + if f.Name != "" { + placeholderFileName = f.Name + break + } + } + + if placeholderFileName == "" { + // File name not available yet, generate one + m.l.Debug("File name not available yet, generate one to validate the destination") + placeholderFileName = uuid.Must(uuid.NewV4()).String() + } + + // Create a placeholder file then delete it to validate the destination + session, err := fm.PrepareUpload(ctx, &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: dstUri.Join(path.Base(placeholderFileName)), + Size: status.Total, + UploadSessionID: uuid.Must(uuid.NewV4()).String(), + ExpireAt: time.Now().Add(time.Second * 3600), + }, + }) + if err != nil { + return err + } + + fm.OnUploadFailed(ctx, session) + return nil +} + +func (m *RemoteDownloadTask) Cleanup(ctx context.Context) error { + if m.state.Handle != nil { + if err := m.d.Cancel(ctx, m.state.Handle); err != nil { + m.l.Warning("failed to cancel download task: %s", err) + } + } + + if m.state.Status != nil && m.node.IsMaster() && m.state.Status.SavePath != "" { + if err := os.RemoveAll(m.state.Status.SavePath); err != nil { + m.l.Warning("failed to remove download temp folder: %s", err) + } + } + + return nil +} + +// SetDownloadTarget sets the files to download for the task +func (m *RemoteDownloadTask) SetDownloadTarget(ctx context.Context, args ...*downloader.SetFileToDownloadArgs) error { + if m.state.Handle == nil { + return 
fmt.Errorf("download task not created") + } + + return m.d.SetFilesToDownload(ctx, m.state.Handle, args...) +} + +// CancelDownload cancels the download task +func (m *RemoteDownloadTask) CancelDownload(ctx context.Context) error { + if m.state.Handle == nil { + return nil + } + + return m.d.Cancel(ctx, m.state.Handle) +} + +func (m *RemoteDownloadTask) Summarize(hasher hashid.Encoder) *queue.Summary { + // unmarshal state + if m.state == nil { + if err := json.Unmarshal([]byte(m.State()), &m.state); err != nil { + return nil + } + } + + var status *downloader.TaskStatus + if m.state.Status != nil { + status = &*m.state.Status + + // Redact save path + status.SavePath = "" + } + + failed := m.state.Failed + if m.state.SlaveUploadState != nil && m.state.Phase != RemoteDownloadTaskPhaseTransfer { + failed = len(m.state.SlaveUploadState.Files) - len(m.state.SlaveUploadState.Transferred) + } + + return &queue.Summary{ + Phase: string(m.state.Phase), + NodeID: m.state.NodeID, + Props: map[string]any{ + SummaryKeySrcStr: m.state.SrcUri, + SummaryKeySrc: m.state.SrcFileUri, + SummaryKeyDst: m.state.Dst, + SummaryKeyFailed: failed, + SummaryKeyDownloadStatus: status, + }, + } +} + +func (m *RemoteDownloadTask) Progress(ctx context.Context) queue.Progresses { + m.Lock() + defer m.Unlock() + + if m.state.NodeState.progress != nil { + merged := make(queue.Progresses) + for k, v := range m.progress { + merged[k] = v + } + + for k, v := range m.state.NodeState.progress { + merged[k] = v + } + + return merged + } + return m.progress +} diff --git a/pkg/filemanager/workflows/upload.go b/pkg/filemanager/workflows/upload.go new file mode 100644 index 00000000..01841785 --- /dev/null +++ b/pkg/filemanager/workflows/upload.go @@ -0,0 +1,224 @@ +package workflows + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "sync/atomic" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" +) + +type ( + SlaveUploadEntity struct { + Uri *fs.URI `json:"uri"` + Src string `json:"src"` + Size int64 `json:"size"` + Index int `json:"index"` + } + SlaveUploadTaskState struct { + MaxParallel int `json:"max_parallel"` + Files []SlaveUploadEntity `json:"files"` + Transferred map[int]interface{} `json:"transferred"` + UserID int `json:"user_id"` + First5TransferErrors string `json:"first_5_transfer_errors,omitempty"` + } + SlaveUploadTask struct { + *queue.InMemoryTask + + progress queue.Progresses + l logging.Logger + state *SlaveUploadTaskState + node cluster.Node + } +) + +// NewSlaveUploadTask creates a new SlaveUploadTask from raw private state +func NewSlaveUploadTask(ctx context.Context, props *types.SlaveTaskProps, id int, state string) queue.Task { + return &SlaveUploadTask{ + InMemoryTask: &queue.InMemoryTask{ + DBTask: &queue.DBTask{ + Task: &ent.Task{ + ID: id, + CorrelationID: logging.CorrelationID(ctx), + PublicState: &types.TaskPublicState{ + SlaveTaskProps: props, + }, + PrivateState: state, + }, + }, + }, + + progress: make(queue.Progresses), + } +} + +func (t *SlaveUploadTask) Do(ctx context.Context) (task.Status, 
error) { + ctx = prepareSlaveTaskCtx(ctx, t.Model().PublicState.SlaveTaskProps) + dep := dependency.FromContext(ctx) + t.l = dep.Logger() + + np, err := dep.NodePool(ctx) + if err != nil { + return task.StatusError, fmt.Errorf("failed to get node pool: %w", err) + } + + t.node, err = np.Get(ctx, types.NodeCapabilityNone, 0) + if err != nil || !t.node.IsMaster() { + return task.StatusError, fmt.Errorf("failed to get master node: %w", err) + } + + fm := manager.NewFileManager(dep, nil) + + // unmarshal state + state := &SlaveUploadTaskState{} + if err := json.Unmarshal([]byte(t.State()), state); err != nil { + return task.StatusError, fmt.Errorf("failed to unmarshal state: %w", err) + } + + t.state = state + if t.state.Transferred == nil { + t.state.Transferred = make(map[int]interface{}) + } + + wg := sync.WaitGroup{} + worker := make(chan int, t.state.MaxParallel) + for i := 0; i < t.state.MaxParallel; i++ { + worker <- i + } + + // Sum up total count + totalCount := 0 + totalSize := int64(0) + for _, res := range state.Files { + totalSize += res.Size + totalCount++ + } + t.Lock() + t.progress[ProgressTypeUploadCount] = &queue.Progress{} + t.progress[ProgressTypeUpload] = &queue.Progress{} + t.Unlock() + atomic.StoreInt64(&t.progress[ProgressTypeUploadCount].Total, int64(totalCount)) + atomic.StoreInt64(&t.progress[ProgressTypeUpload].Total, totalSize) + ae := serializer.NewAggregateError() + transferFunc := func(workerId, fileId int, file SlaveUploadEntity) { + defer func() { + atomic.AddInt64(&t.progress[ProgressTypeUploadCount].Current, 1) + worker <- workerId + wg.Done() + }() + + t.l.Info("Uploading file %s to %s...", file.Src, file.Uri.String()) + + progressKey := fmt.Sprintf("%s%d", ProgressTypeUploadSinglePrefix, workerId) + t.Lock() + t.progress[progressKey] = &queue.Progress{Identifier: file.Uri.String(), Total: file.Size} + t.Unlock() + + handle, err := os.Open(file.Src) + if err != nil { + t.l.Warning("Failed to open file %s: %s", file.Src, err.Error()) + atomic.AddInt64(&t.progress[ProgressTypeUpload].Current, file.Size) + ae.Add(filepath.Base(file.Src), fmt.Errorf("failed to open file: %w", err)) + return + } + + stat, err := handle.Stat() + if err != nil { + t.l.Warning("Failed to get file stat for %s: %s", file.Src, err.Error()) + handle.Close() + atomic.AddInt64(&t.progress[ProgressTypeUpload].Current, file.Size) + ae.Add(filepath.Base(file.Src), fmt.Errorf("failed to get file stat: %w", err)) + return + } + + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: file.Uri, + Size: stat.Size(), + }, + ProgressFunc: func(current, diff int64, total int64) { + atomic.AddInt64(&t.progress[progressKey].Current, diff) + atomic.AddInt64(&t.progress[ProgressTypeUpload].Current, diff) + atomic.StoreInt64(&t.progress[progressKey].Total, total) + }, + File: handle, + Seeker: handle, + } + + _, err = fm.Update(ctx, fileData, fs.WithNode(t.node), fs.WithStatelessUserID(t.state.UserID), fs.WithNoEntityType()) + if err != nil { + handle.Close() + t.l.Warning("Failed to upload file %s: %s", file.Src, err.Error()) + atomic.AddInt64(&t.progress[ProgressTypeUpload].Current, file.Size) + ae.Add(filepath.Base(file.Src), fmt.Errorf("failed to upload file: %w", err)) + return + } + + t.Lock() + t.state.Transferred[fileId] = nil + t.Unlock() + handle.Close() + } + + // Start upload files + for fileId, file := range t.state.Files { + // Check if file is already transferred + if _, ok := t.state.Transferred[fileId]; ok { + t.l.Info("File %s already transferred, skipping...", file.Src) + 
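// Already-transferred files still count toward the progress totals so resumed runs report accurate overall progress. +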
atomic.AddInt64(&t.progress[ProgressTypeUpload].Current, file.Size) + atomic.AddInt64(&t.progress[ProgressTypeUploadCount].Current, 1) + continue + } + + select { + case <-ctx.Done(): + return task.StatusError, ctx.Err() + case workerId := <-worker: + wg.Add(1) + + go transferFunc(workerId, fileId, file) + } + } + + wg.Wait() + + t.state.First5TransferErrors = ae.FormatFirstN(5) + newStateStr, marshalErr := json.Marshal(t.state) + if marshalErr != nil { + return task.StatusError, fmt.Errorf("failed to marshal state: %w", marshalErr) + } + t.Lock() + t.Task.PrivateState = string(newStateStr) + t.Unlock() + + // If all files failed to transfer, return an error + if len(t.state.Transferred) != len(t.state.Files) { + t.l.Warning("%d files not transferred", len(t.state.Files)-len(t.state.Transferred)) + if len(t.state.Transferred) == 0 { + return task.StatusError, fmt.Errorf("all files failed to transfer") + } + } + + return task.StatusCompleted, nil +} + +func (t *SlaveUploadTask) Progress(ctx context.Context) queue.Progresses { + t.Lock() + defer t.Unlock() + + return t.progress +} diff --git a/pkg/filemanager/workflows/worfklows.go b/pkg/filemanager/workflows/worfklows.go new file mode 100644 index 00000000..475dd9be --- /dev/null +++ b/pkg/filemanager/workflows/worfklows.go @@ -0,0 +1,62 @@ +package workflows + +import ( + "context" + "fmt" + "path" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" +) + +const ( + TaskTempPath = "fm_workflows" + slaveProgressRefreshInterval = 5 * time.Second +) + +type NodeState struct { + NodeID int `json:"node_id"` + + progress queue.Progresses +} + +// allocateNode allocates a node for the task. +func allocateNode(ctx context.Context, dep dependency.Dep, state *NodeState, capability types.NodeCapability) (cluster.Node, error) { + np, err := dep.NodePool(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get node pool: %w", err) + } + + node, err := np.Get(ctx, capability, state.NodeID) + if err != nil { + return nil, fmt.Errorf("failed to get node: %w", err) + } + + state.NodeID = node.ID() + return node, nil +} + +// prepareSlaveTaskCtx prepares the context for the slave task.
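+// It injects the slave node ID and the master site URL, version and ID into the context so slave-side calls can reach back to the master.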
+func prepareSlaveTaskCtx(ctx context.Context, props *types.SlaveTaskProps) context.Context { + ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, strconv.Itoa(props.NodeID)) + ctx = context.WithValue(ctx, cluster.MasterSiteUrlCtx{}, props.MasterSiteURl) + ctx = context.WithValue(ctx, cluster.MasterSiteVersionCtx{}, props.MasterSiteVersion) + ctx = context.WithValue(ctx, cluster.MasterSiteIDCtx{}, props.MasterSiteID) + return ctx +} + +func prepareTempFolder(ctx context.Context, dep dependency.Dep, t queue.Task) (string, error) { + settings := dep.SettingProvider() + tempPath := util.DataPath(path.Join(settings.TempPath(ctx), TaskTempPath, strconv.Itoa(t.ID()))) + if err := util.CreatNestedFolder(tempPath); err != nil { + return "", fmt.Errorf("failed to create temp folder: %w", err) + } + + dep.Logger().Info("Temp folder created: %s", tempPath) + return tempPath, nil +} diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go deleted file mode 100644 index 78fc45fd..00000000 --- a/pkg/filesystem/archive.go +++ /dev/null @@ -1,316 +0,0 @@ -package filesystem - -import ( - "archive/zip" - "context" - "fmt" - "io" - "os" - "path" - "path/filepath" - "strings" - "sync" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/mholt/archiver/v4" -) - -/* =============== - 压缩/解压缩 - =============== -*/ - -// Compress 创建给定目录和文件的压缩文件 -func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error { - // 查找待压缩目录 - folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID) - if err != nil && len(folderIDs) != 0 { - return ErrDBListObjects - } - - // 查找待压缩文件 - files, err := model.GetFilesByIDs(fileIDs, fs.User.ID) - if err != nil && len(fileIDs) != 0 { - return ErrDBListObjects - } - - // 如果上下文限制了父目录,则进行检查 - if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok { - // 检查目录 - for _, folder := range folders { - if *folder.ParentID != parent.ID { - return ErrObjectNotExist - } - } - - // 检查文件 - for _, file := range files { - if file.FolderID != parent.ID { - return ErrObjectNotExist - } - } - } - - // 尝试获取请求上下文,以便于后续检查用户取消任务 - reqContext := ctx - ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context) - if ok { - reqContext = ginCtx.Request.Context() - } - - // 将顶级待处理对象的路径设为根路径 - for i := 0; i < len(folders); i++ { - folders[i].Position = "" - } - for i := 0; i < len(files); i++ { - files[i].Position = "" - } - - // 创建压缩文件Writer - zipWriter := zip.NewWriter(writer) - defer zipWriter.Close() - - ctx = reqContext - - // 压缩各个目录及文件 - for i := 0; i < len(folders); i++ { - select { - case <-reqContext.Done(): - // 取消压缩请求 - return ErrClientCanceled - default: - fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive) - } - - } - for i := 0; i < len(files); i++ { - select { - case <-reqContext.Done(): - // 取消压缩请求 - return ErrClientCanceled - default: - fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive) - } - } - - return nil -} - -func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) { - // 如果对象是文件 - if file != nil { - // 切换上传策略 - fs.Policy = file.GetPolicy() - err := fs.DispatchHandler() - if err != nil { - util.Log().Warning("Failed to compress file %q: %s", file.Name, err) - return - } - - // 获取文件内容 - fileToZip, err := fs.Handler.Get( - context.WithValue(ctx, 
fsctx.FileModelCtx, *file), - file.SourceName, - ) - if err != nil { - util.Log().Debug("Failed to open %q: %s", file.Name, err) - return - } - if closer, ok := fileToZip.(io.Closer); ok { - defer closer.Close() - } - - // 创建压缩文件头 - header := &zip.FileHeader{ - Name: filepath.FromSlash(path.Join(file.Position, file.Name)), - Modified: file.UpdatedAt, - UncompressedSize64: file.Size, - } - - // 指定是压缩还是归档 - if isArchive { - header.Method = zip.Store - } else { - header.Method = zip.Deflate - } - - writer, err := zipWriter.CreateHeader(header) - if err != nil { - return - } - - _, err = io.Copy(writer, fileToZip) - } else if folder != nil { - // 对象是目录 - // 获取子文件 - subFiles, err := folder.GetChildFiles() - if err == nil && len(subFiles) > 0 { - for i := 0; i < len(subFiles); i++ { - fs.doCompress(ctx, &subFiles[i], nil, zipWriter, isArchive) - } - - } - // 获取子目录,继续递归遍历 - subFolders, err := folder.GetChildFolder() - if err == nil && len(subFolders) > 0 { - for i := 0; i < len(subFolders); i++ { - fs.doCompress(ctx, nil, &subFolders[i], zipWriter, isArchive) - } - } - } -} - -// Decompress 解压缩给定压缩文件到dst目录 -func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error { - err := fs.ResetFileIfNotExist(ctx, src) - if err != nil { - return err - } - - tempZipFilePath := "" - defer func() { - // 结束时删除临时压缩文件 - if tempZipFilePath != "" { - if err := os.Remove(tempZipFilePath); err != nil { - util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err) - } - } - }() - - // 下载压缩文件到临时目录 - fileStream, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName) - if err != nil { - return err - } - - defer fileStream.Close() - - tempZipFilePath = filepath.Join( - util.RelativePath(model.GetSettingByName("temp_path")), - "decompress", - fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()), - ) - - zipFile, err := util.CreatNestedFile(tempZipFilePath) - if err != nil { - util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err) - tempZipFilePath = "" - return err - } - defer zipFile.Close() - - // 下载前先判断是否是可解压的格式 - format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream) - if err != nil { - util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err) - return err - } - - extractor, ok := format.(archiver.Extractor) - if !ok { - return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName) - } - - // 只有zip格式可以多个文件同时上传 - var isZip bool - switch extractor.(type) { - case archiver.Zip: - extractor = archiver.Zip{TextEncoding: encoding} - isZip = true - } - - // 除了zip必须下载到本地,其余的可以边下载边解压 - reader := readStream - if isZip { - _, err = io.Copy(zipFile, readStream) - if err != nil { - util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err) - return err - } - - fileStream.Close() - - // 设置文件偏移量 - zipFile.Seek(0, io.SeekStart) - reader = zipFile - } - - // 重设存储策略 - fs.Policy = &fs.User.Policy - err = fs.DispatchHandler() - if err != nil { - return err - } - - var wg sync.WaitGroup - parallel := model.GetIntSetting("max_parallel_transfer", 4) - worker := make(chan int, parallel) - for i := 0; i < parallel; i++ { - worker <- i - } - - // 上传文件函数 - uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) { - defer func() { - if isZip { - worker <- 1 - wg.Done() - } - if err := recover(); err != nil { - util.Log().Warning("Error while uploading files inside of archive file.") - fmt.Println(err) - } - }() - - err 
:= fs.UploadFromStream(ctx, &fsctx.FileStream{ - File: fileStream, - Size: uint64(size), - Name: path.Base(savePath), - VirtualPath: path.Dir(savePath), - }, true) - fileStream.Close() - if err != nil { - util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err) - } - } - - // 解压缩文件,回调函数如果出错会停止解压的下一步进行,全部return nil - err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error { - rawPath := util.FormSlash(f.NameInArchive) - savePath := path.Join(dst, rawPath) - // 路径是否合法 - if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) { - util.Log().Warning("%s: illegal file path", f.NameInArchive) - return nil - } - - // 如果是目录 - if f.FileInfo.IsDir() { - fs.CreateDirectory(ctx, savePath) - return nil - } - - // 上传文件 - fileStream, err := f.Open() - if err != nil { - util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err) - return nil - } - - if !isZip { - uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath) - } else { - <-worker - wg.Add(1) - go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath) - } - return nil - }) - wg.Wait() - return err - -} diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go deleted file mode 100644 index 07f50875..00000000 --- a/pkg/filesystem/archive_test.go +++ /dev/null @@ -1,256 +0,0 @@ -package filesystem - -import ( - "bytes" - "context" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - testMock "github.com/stretchr/testify/mock" - "io" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_Compress(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - - // 成功 - { - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - // 查找父目录子文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"})) - // 查找子目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "sub")) - // 查找子目录子文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}). - AddRow(2, "2.txt", "tests/file2.txt", 1), - ) - // 查找上传策略 - asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1)) - w := &bytes.Buffer{} - - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.NoError(err) - asserts.NotEmpty(w.Len()) - } - - // 上下文取消 - { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). 
- WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - - w := &bytes.Buffer{} - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.Error(err) - asserts.NotEmpty(w.Len()) - } - - // 限制父目录 - { - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{ - Model: gorm.Model{ID: 3}, - }) - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(1, "parent", 3)) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - - w := &bytes.Buffer{} - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.Error(err) - asserts.Equal(ErrObjectNotExist, err) - asserts.Empty(w.Len()) - } - -} - -type MockNopRSC string - -func (m MockNopRSC) Read(b []byte) (int, error) { - return 0, errors.New("read error") -} - -func (m MockNopRSC) Seek(n int64, offset int) (int64, error) { - return 0, errors.New("read error") -} - -func (m MockNopRSC) Close() error { - return errors.New("read error") -} - -type MockRSC struct { - rs io.ReadSeeker -} - -func (m MockRSC) Read(b []byte) (int, error) { - return m.rs.Read(b) -} - -func (m MockRSC) Seek(n int64, offset int) (int64, error) { - return m.rs.Seek(n, offset) -} - -func (m MockRSC) Close() error { - return nil -} - -var basepath string - -func init() { - _, currentFile, _, _ := runtime.Caller(0) - basepath = filepath.Dir(currentFile) -} - -func Path(rel string) string { - return filepath.Join(basepath, rel) -} - -func TestFileSystem_Decompress(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - os.RemoveAll(util.RelativePath("tests/decompress")) - - // 压缩文件不存在 - { - // 查找根目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "/")) - // 查找压缩文件,未找到 - mock.ExpectQuery("SELECT(.+)files(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 无法下载压缩文件 - { - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, errors.New("error")) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualError(err, "error") - } - - // 无法创建临时压缩文件 - { - cache.Set("setting_temp_path", "/tests:", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 无法写入压缩文件 - { - cache.Set("setting_temp_path", "tests", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Contains(err.Error(), "read error") - } - - // 无法重设上传策略 - { - cache.Set("setting_temp_path", "tests", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.True(util.IsEmpty(util.RelativePath("tests/decompress"))) - } - - // 无法上传,容量不足 - { - cache.Set("setting_max_parallel_transfer", "1", 0) - zipFile, _ := os.Open(Path("tests/test.zip")) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - fs.User.Policy.Type = "mock" - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil) - fs.Handler = testHandler - - fs.Decompress(ctx, "/1.zip", "/", "") - - zipFile.Close() - - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - } -} diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go deleted file mode 100644 index 0fda5347..00000000 --- a/pkg/filesystem/chunk/backoff/backoff_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package backoff - -import ( - "errors" - "github.com/stretchr/testify/assert" - "net/http" - "testing" - "time" -) - -func TestConstantBackoff_Next(t *testing.T) { - a := assert.New(t) - - // General error - { - err := errors.New("error") - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - b.Reset() - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - } - - // Retryable error - { - err := &RetryableError{RetryAfter: time.Duration(1)} - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - 
a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - b.Reset() - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - } - -} - -func TestNewRetryableErrorFromHeader(t *testing.T) { - a := assert.New(t) - // no retry-after header - { - err := NewRetryableErrorFromHeader(nil, http.Header{}) - a.Empty(err.RetryAfter) - } - - // with retry-after header - { - header := http.Header{} - header.Add("retry-after", "120") - err := NewRetryableErrorFromHeader(nil, header) - a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter) - } -} diff --git a/pkg/filesystem/chunk/chunk_test.go b/pkg/filesystem/chunk/chunk_test.go deleted file mode 100644 index 4bdcd06d..00000000 --- a/pkg/filesystem/chunk/chunk_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package chunk - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/stretchr/testify/assert" - "io" - "os" - "strings" - "testing" -) - -func TestNewChunkGroup(t *testing.T) { - a := assert.New(t) - - testCases := []struct { - fileSize uint64 - chunkSize uint64 - expectedInnerChunkSize uint64 - expectedChunkNum uint64 - expectedInfo [][2]int //Start, Index,Length - }{ - {10, 0, 10, 1, [][2]int{{0, 10}}}, - {0, 0, 0, 1, [][2]int{{0, 0}}}, - {0, 10, 10, 1, [][2]int{{0, 0}}}, - {50, 10, 10, 5, [][2]int{ - {0, 10}, - {10, 10}, - {20, 10}, - {30, 10}, - {40, 10}, - }}, - {50, 50, 50, 1, [][2]int{ - {0, 50}, - }}, - - {50, 15, 15, 4, [][2]int{ - {0, 15}, - {15, 15}, - {30, 15}, - {45, 5}, - }}, - } - - for index, testCase := range testCases { - file := &fsctx.FileStream{Size: testCase.fileSize} - chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true) - a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), - "TestCase:%d,ChunkNum()", index) - a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize, - "TestCase:%d,InnerChunkSize()", index) - a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), - "TestCase:%d,len(Chunks)", index) - a.EqualValues(testCase.fileSize, chunkGroup.Total()) - - for cIndex, info := range testCase.expectedInfo { - a.True(chunkGroup.Next()) - a.EqualValues(info[1], chunkGroup.Length(), - "TestCase:%d,Chunks[%d].Length()", index, cIndex) - a.EqualValues(info[0], chunkGroup.Start(), - "TestCase:%d,Chunks[%d].Start()", index, cIndex) - - a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(), - "TestCase:%d,Chunks[%d].IsLast()", index, cIndex) - - a.NotEmpty(chunkGroup.RangeHeader()) - } - a.False(chunkGroup.Next()) - } -} - -func TestChunkGroup_TempAvailablet(t *testing.T) { - a := assert.New(t) - - file := &fsctx.FileStream{Size: 1} - c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true) - a.False(c.TempAvailable()) - - f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*") - defer func() { - f.Close() - os.Remove(f.Name()) - }() - a.NoError(err) - c.bufferTemp = f - - a.False(c.TempAvailable()) - f.Write([]byte("1")) - a.True(c.TempAvailable()) - -} - -func TestChunkGroup_Process(t *testing.T) { - a := assert.New(t) - file := &fsctx.FileStream{Size: 10} - - // success - { - file.File = io.NopCloser(strings.NewReader("1234567890")) - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", 
string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - return nil - })) - a.False(c.Next()) - a.Equal(2, count) - } - - // retry, read from buffer file - { - file.File = io.NopCloser(strings.NewReader("1234567890")) - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(3, count) - } - - // retry, read from seeker - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(3, count) - } - - // retry, seek error - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - f.Close() - a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(2, count) - } - - // retry, finally error - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - return errors.New("error") - })) - a.False(c.Next()) - a.Equal(4, count) - } -} diff --git a/pkg/filesystem/driver/cos/handler.go b/pkg/filesystem/driver/cos/handler.go deleted file mode 100644 index 50b500c5..00000000 --- a/pkg/filesystem/driver/cos/handler.go +++ /dev/null @@ -1,427 +0,0 @@ -package cos - -import ( - "context" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - 
"net/url" - "path" - "path/filepath" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/google/go-querystring/query" - cossdk "github.com/tencentyun/cos-go-sdk-v5" -) - -// UploadPolicy 腾讯云COS上传策略 -type UploadPolicy struct { - Expiration string `json:"expiration"` - Conditions []interface{} `json:"conditions"` -} - -// MetaData 文件元信息 -type MetaData struct { - Size uint64 - CallbackKey string - CallbackURL string -} - -type urlOption struct { - Speed int `url:"x-cos-traffic-limit,omitempty"` - ContentDescription string `url:"response-content-disposition,omitempty"` -} - -// Driver 腾讯云COS适配器模板 -type Driver struct { - Policy *model.Policy - Client *cossdk.Client - HTTPClient request.Client -} - -// List 列出COS文件 -func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - // 初始化列目录参数 - opt := &cossdk.BucketGetOptions{ - Prefix: strings.TrimPrefix(base, "/"), - EncodingType: "", - MaxKeys: 1000, - } - // 是否为递归列出 - if !recursive { - opt.Delimiter = "/" - } - // 手动补齐结尾的slash - if opt.Prefix != "" { - opt.Prefix += "/" - } - - var ( - marker string - objects []cossdk.Object - commons []string - ) - - for { - res, _, err := handler.Client.Bucket.Get(ctx, opt) - if err != nil { - return nil, err - } - objects = append(objects, res.Contents...) - commons = append(commons, res.CommonPrefixes...) - // 如果本次未列取完,则继续使用marker获取结果 - marker = res.NextMarker - // marker 为空时结果列取完毕,跳出 - if marker == "" { - break - } - } - - // 处理列取结果 - res := make([]response.Object, 0, len(objects)+len(commons)) - // 处理目录 - for _, object := range commons { - rel, err := filepath.Rel(opt.Prefix, object) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object), - RelativePath: filepath.ToSlash(rel), - Size: 0, - IsDir: true, - LastModify: time.Now(), - }) - } - // 处理文件 - for _, object := range objects { - rel, err := filepath.Rel(opt.Prefix, object.Key) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object.Key), - Source: object.Key, - RelativePath: filepath.ToSlash(rel), - Size: uint64(object.Size), - IsDir: false, - LastModify: time.Now(), - }) - } - - return res, nil - -} - -// CORS 创建跨域策略 -func (handler Driver) CORS() error { - _, err := handler.Client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{ - Rules: []cossdk.BucketCORSRule{{ - AllowedMethods: []string{ - "GET", - "POST", - "PUT", - "DELETE", - "HEAD", - }, - AllowedOrigins: []string{"*"}, - AllowedHeaders: []string{"*"}, - MaxAgeSeconds: 3600, - ExposeHeaders: []string{}, - }}, - }) - - return err -} - -// Get 获取文件 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - // 获取文件数据流 - resp, err := handler.HTTPClient.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := 
ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - opt := &cossdk.ObjectPutOptions{} - _, err := handler.Client.Object.Put(ctx, file.Info().SavePath, file, opt) - return err -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { - obs := []cossdk.Object{} - for _, v := range files { - obs = append(obs, cossdk.Object{Key: v}) - } - opt := &cossdk.ObjectDeleteMultiOptions{ - Objects: obs, - Quiet: true, - } - - res, _, err := handler.Client.Object.DeleteMulti(context.Background(), opt) - if err != nil { - return files, err - } - - // 整理删除结果 - failed := make([]string, 0, len(files)) - for _, v := range res.Errors { - failed = append(failed, v.Key) - } - - if len(failed) == 0 { - return failed, nil - } - - return failed, errors.New("delete failed") -} - -// Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // quick check by extension name - // https://cloud.tencent.com/document/product/436/44893 - supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"} - if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 { - supported = handler.Policy.OptionsSerialized.ThumbExts - } - - if !util.IsInExtensionList(supported, file.Name) || file.Size > (32<<(10*2)) { - return nil, driver.ErrorThumbNotSupported - } - - var ( - thumbSize = [2]uint{400, 300} - ok = false - ) - if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok { - return nil, errors.New("failed to get thumbnail size") - } - - thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85) - - thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality) - - source, err := handler.signSourceURL( - ctx, - file.SourceName, - int64(model.GetIntSetting("preview_timeout", 60)), - &urlOption{}, - ) - if err != nil { - return nil, err - } - - thumbURL, _ := url.Parse(source) - thumbQuery := thumbURL.Query() - thumbQuery.Add(thumbParam, "") - thumbURL.RawQuery = thumbQuery.Encode() - - return &response.ContentResponse{ - Redirect: true, - URL: thumbURL.String(), - }, nil -} - -// Source 获取外链URL -func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - // 尝试从上下文获取文件名 - fileName := "" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - // 添加各项设置 - options := urlOption{} - if speed > 0 { - if speed < 819200 { - speed = 819200 - } - if speed > 838860800 { - speed = 838860800 - } - options.Speed = speed - } - if isDownload { - options.ContentDescription = "attachment; filename=\"" + url.PathEscape(fileName) + "\"" - } - - return handler.signSourceURL(ctx, path, ttl, &options) -} - -func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, options *urlOption) (string, error) { - cdnURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - - // 公有空间不需要签名 - if !handler.Policy.IsPrivate { - file, err := url.Parse(path) - if err != nil { - return "", err - } - - // 非签名URL不支持设置响应header - options.ContentDescription = "" - - optionQuery, err := query.Values(*options) - if err != nil { - return "", err - } - file.RawQuery = optionQuery.Encode() - sourceURL := 
cdnURL.ResolveReference(file) - - return sourceURL.String(), nil - } - - presignedURL, err := handler.Client.Object.GetPresignedURL(ctx, http.MethodGet, path, - handler.Policy.AccessKey, handler.Policy.SecretKey, time.Duration(ttl)*time.Second, options) - if err != nil { - return "", err - } - - // 将最终生成的签名URL域名换成用户自定义的加速域名(如果有) - presignedURL.Host = cdnURL.Host - presignedURL.Scheme = cdnURL.Scheme - - return presignedURL.String(), nil -} - -// Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 生成回调地址 - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key) - apiURL := siteURL.ResolveReference(apiBaseURI).String() - - // 上传策略 - savePath := file.Info().SavePath - startTime := time.Now() - endTime := startTime.Add(time.Duration(ttl) * time.Second) - keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix()) - postPolicy := UploadPolicy{ - Expiration: endTime.UTC().Format(time.RFC3339), - Conditions: []interface{}{ - map[string]string{"bucket": handler.Policy.BucketName}, - map[string]string{"$key": savePath}, - map[string]string{"x-cos-meta-callback": apiURL}, - map[string]string{"x-cos-meta-key": uploadSession.Key}, - map[string]string{"q-sign-algorithm": "sha1"}, - map[string]string{"q-ak": handler.Policy.AccessKey}, - map[string]string{"q-sign-time": keyTime}, - }, - } - - if handler.Policy.MaxSize > 0 { - postPolicy.Conditions = append(postPolicy.Conditions, - []interface{}{"content-length-range", 0, handler.Policy.MaxSize}) - } - - res, err := handler.getUploadCredential(ctx, postPolicy, keyTime, savePath) - if err == nil { - res.SessionID = uploadSession.Key - res.Callback = apiURL - res.UploadURLs = []string{handler.Policy.Server} - } - - return res, err - -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil -} - -// Meta 获取文件信息 -func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) { - res, err := handler.Client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{}) - if err != nil { - return nil, err - } - return &MetaData{ - Size: uint64(res.ContentLength), - CallbackKey: res.Header.Get("x-cos-meta-key"), - CallbackURL: res.Header.Get("x-cos-meta-callback"), - }, nil -} - -func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string, savePath string) (*serializer.UploadCredential, error) { - // 编码上传策略 - policyJSON, err := json.Marshal(policy) - if err != nil { - return nil, err - } - policyEncoded := base64.StdEncoding.EncodeToString(policyJSON) - - // 签名上传策略 - hmacSign := hmac.New(sha1.New, []byte(handler.Policy.SecretKey)) - _, err = io.WriteString(hmacSign, keyTime) - if err != nil { - return nil, err - } - signKey := fmt.Sprintf("%x", hmacSign.Sum(nil)) - - sha1Sign := sha1.New() - _, err = sha1Sign.Write(policyJSON) - if err != nil { - return nil, err - } - stringToSign := fmt.Sprintf("%x", sha1Sign.Sum(nil)) - - // 最终签名 - hmacFinalSign := hmac.New(sha1.New, []byte(signKey)) - _, err = hmacFinalSign.Write([]byte(stringToSign)) - if err != nil { - return nil, err - } - signature := hmacFinalSign.Sum(nil) - - return &serializer.UploadCredential{ - Policy: policyEncoded, - Path: savePath, - AccessKey: handler.Policy.AccessKey, - Credential: fmt.Sprintf("%x", signature), - KeyTime: keyTime, - }, nil -} diff --git a/pkg/filesystem/driver/cos/scf.go 
b/pkg/filesystem/driver/cos/scf.go deleted file mode 100644 index 9ddb29c1..00000000 --- a/pkg/filesystem/driver/cos/scf.go +++ /dev/null @@ -1,134 +0,0 @@ -package cos - -import ( - "archive/zip" - "bytes" - "encoding/base64" - "io" - "io/ioutil" - "net/url" - "strconv" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" - scf "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf/v20180416" -) - -const scfFunc = `# -*- coding: utf8 -*- -# SCF配置COS触发,向 Cloudreve 发送回调 -from qcloud_cos_v5 import CosConfig -from qcloud_cos_v5 import CosS3Client -from qcloud_cos_v5 import CosServiceError -from qcloud_cos_v5 import CosClientError -import sys -import logging -import requests - -logging.basicConfig(level=logging.INFO, stream=sys.stdout) -logger = logging.getLogger() - - -def main_handler(event, context): - logger.info("start main handler") - for record in event['Records']: - try: - if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']: - logger.info("Cannot find callback URL, skiped.") - return 'Success' - callback = record['cos']['cosObject']['meta']['x-cos-meta-callback'] - key = record['cos']['cosObject']['key'] - logger.info("Callback URL is " + callback) - - r = requests.get(callback) - print(r.text) - - - - except Exception as e: - print(e) - print('Error getting object {} callback url. '.format(key)) - raise e - return "Fail" - - return "Success" -` - -// CreateSCF 创建回调云函数 -func CreateSCF(policy *model.Policy, region string) error { - // 初始化客户端 - credential := common.NewCredential( - policy.AccessKey, - policy.SecretKey, - ) - cpf := profile.NewClientProfile() - client, err := scf.NewClient(credential, region, cpf) - if err != nil { - return err - } - - // 创建回调代码数据 - buff := &bytes.Buffer{} - bs64 := base64.NewEncoder(base64.StdEncoding, buff) - zipWriter := zip.NewWriter(bs64) - header := zip.FileHeader{ - Name: "callback.py", - Method: zip.Deflate, - } - writer, err := zipWriter.CreateHeader(&header) - if err != nil { - return err - } - _, err = io.Copy(writer, strings.NewReader(scfFunc)) - zipWriter.Close() - - // 创建云函数 - req := scf.NewCreateFunctionRequest() - funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10) - zipFileBytes, _ := ioutil.ReadAll(buff) - zipFileStr := string(zipFileBytes) - codeSource := "ZipFile" - handler := "callback.main_handler" - desc := "Cloudreve 用回调函数" - timeout := int64(60) - runtime := "Python3.6" - req.FunctionName = &funcName - req.Code = &scf.Code{ - ZipFile: &zipFileStr, - } - req.Handler = &handler - req.Description = &desc - req.Timeout = &timeout - req.Runtime = &runtime - req.CodeSource = &codeSource - - _, err = client.CreateFunction(req) - if err != nil { - return err - } - - time.Sleep(time.Duration(5) * time.Second) - - // 创建触发器 - server, _ := url.Parse(policy.Server) - triggerType := "cos" - triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}` - enable := "OPEN" - - trigger := scf.NewCreateTriggerRequest() - trigger.FunctionName = &funcName - trigger.TriggerName = &server.Host - trigger.Type = &triggerType - trigger.TriggerDesc = &triggerDesc - trigger.Enable = &enable - - _, err = client.CreateTrigger(trigger) - if err != nil { - return err - } - - return nil -} diff --git a/pkg/filesystem/driver/googledrive/client.go 
b/pkg/filesystem/driver/googledrive/client.go deleted file mode 100644 index de37257f..00000000 --- a/pkg/filesystem/driver/googledrive/client.go +++ /dev/null @@ -1,73 +0,0 @@ -package googledrive - -import ( - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "google.golang.org/api/drive/v3" -) - -// Client Google Drive client -type Client struct { - Endpoints *Endpoints - Policy *model.Policy - Credential *Credential - - ClientID string - ClientSecret string - Redirect string - - Request request.Client - ClusterController cluster.Controller -} - -// Endpoints OneDrive客户端相关设置 -type Endpoints struct { - UserConsentEndpoint string // OAuth认证的基URL - TokenEndpoint string // OAuth token 基URL - EndpointURL string // 接口请求的基URL -} - -const ( - TokenCachePrefix = "googledrive_" - - oauthEndpoint = "https://oauth2.googleapis.com/token" - userConsentBase = "https://accounts.google.com/o/oauth2/auth" - v3DriveEndpoint = "https://www.googleapis.com/drive/v3" -) - -var ( - // Defualt required scopes - RequiredScope = []string{ - drive.DriveScope, - "openid", - "profile", - "https://www.googleapis.com/auth/userinfo.profile", - } - - // ErrInvalidRefreshToken 上传策略无有效的RefreshToken - ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy") -) - -// NewClient 根据存储策略获取新的client -func NewClient(policy *model.Policy) (*Client, error) { - client := &Client{ - Endpoints: &Endpoints{ - TokenEndpoint: oauthEndpoint, - UserConsentEndpoint: userConsentBase, - EndpointURL: v3DriveEndpoint, - }, - Credential: &Credential{ - RefreshToken: policy.AccessKey, - }, - Policy: policy, - ClientID: policy.BucketName, - ClientSecret: policy.SecretKey, - Redirect: policy.OptionsSerialized.OauthRedirect, - Request: request.NewClient(), - ClusterController: cluster.DefaultController, - } - - return client, nil -} diff --git a/pkg/filesystem/driver/googledrive/handler.go b/pkg/filesystem/driver/googledrive/handler.go deleted file mode 100644 index 917ae872..00000000 --- a/pkg/filesystem/driver/googledrive/handler.go +++ /dev/null @@ -1,65 +0,0 @@ -package googledrive - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Driver Google Drive 适配器 -type Driver struct { - Policy *model.Policy - HTTPClient request.Client -} - -// NewDriver 从存储策略初始化新的Driver实例 -func NewDriver(policy *model.Policy) (driver.Handler, error) { - return &Driver{ - Policy: policy, - HTTPClient: request.NewClient(), - }, nil -} - -func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - //TODO implement me - panic("implement me") -} - -func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - //TODO implement me - panic("implement me") -} - -func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - //TODO implement me - panic("implement me") -} - -func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - //TODO implement me - panic("implement me") -} - -func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - //TODO implement me - panic("implement me") -} 
- -func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - //TODO implement me - panic("implement me") -} - -func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - //TODO implement me - panic("implement me") -} - -func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - //TODO implement me - panic("implement me") -} diff --git a/pkg/filesystem/driver/googledrive/oauth.go b/pkg/filesystem/driver/googledrive/oauth.go deleted file mode 100644 index da8a80a1..00000000 --- a/pkg/filesystem/driver/googledrive/oauth.go +++ /dev/null @@ -1,154 +0,0 @@ -package googledrive - -import ( - "context" - "encoding/json" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io" - "net/http" - "net/url" - "strings" - "time" -) - -// OAuthURL 获取OAuth认证页面URL -func (client *Client) OAuthURL(ctx context.Context, scope []string) string { - query := url.Values{ - "client_id": {client.ClientID}, - "scope": {strings.Join(scope, " ")}, - "response_type": {"code"}, - "redirect_uri": {client.Redirect}, - "access_type": {"offline"}, - "prompt": {"consent"}, - } - - u, _ := url.Parse(client.Endpoints.UserConsentEndpoint) - u.RawQuery = query.Encode() - return u.String() -} - -// ObtainToken 通过code或refresh_token兑换token -func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) { - body := url.Values{ - "client_id": {client.ClientID}, - "redirect_uri": {client.Redirect}, - "client_secret": {client.ClientSecret}, - } - if code != "" { - body.Add("grant_type", "authorization_code") - body.Add("code", code) - } else { - body.Add("grant_type", "refresh_token") - body.Add("refresh_token", refreshToken) - } - strBody := body.Encode() - - res := client.Request.Request( - "POST", - client.Endpoints.TokenEndpoint, - io.NopCloser(strings.NewReader(strBody)), - request.WithHeader(http.Header{ - "Content-Type": {"application/x-www-form-urlencoded"}}, - ), - request.WithContentLength(int64(len(strBody))), - ) - if res.Err != nil { - return nil, res.Err - } - - respBody, err := res.GetResponse() - if err != nil { - return nil, err - } - - var ( - errResp OAuthError - credential Credential - decodeErr error - ) - - if res.Response.StatusCode != 200 { - decodeErr = json.Unmarshal([]byte(respBody), &errResp) - } else { - decodeErr = json.Unmarshal([]byte(respBody), &credential) - } - if decodeErr != nil { - return nil, decodeErr - } - - if errResp.ErrorType != "" { - return nil, errResp - } - - return &credential, nil -} - -// UpdateCredential 更新凭证,并检查有效期 -func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error { - if isSlave { - return client.fetchCredentialFromMaster(ctx) - } - - oauth.GlobalMutex.Lock(client.Policy.ID) - defer oauth.GlobalMutex.Unlock(client.Policy.ID) - - // 如果已存在凭证 - if client.Credential != nil && client.Credential.AccessToken != "" { - // 检查已有凭证是否过期 - if client.Credential.ExpiresIn > time.Now().Unix() { - // 未过期,不要更新 - return nil - } - } - - // 尝试从缓存中获取凭证 - if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok { - credential := cacheCredential.(Credential) - if credential.ExpiresIn > time.Now().Unix() { - client.Credential = &credential - return nil - } - } - - // 获取新的凭证 - if 
client.Credential == nil || client.Credential.RefreshToken == "" { - // 无有效的RefreshToken - util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name) - return ErrInvalidRefreshToken - } - - credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken) - if err != nil { - return err - } - - // 更新有效期为绝对时间戳 - expires := credential.ExpiresIn - 60 - credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix() - // refresh token for Google Drive does not expire in production - credential.RefreshToken = client.Credential.RefreshToken - client.Credential = credential - - // 更新缓存 - cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires)) - - return nil -} - -func (client *Client) AccessToken() string { - return client.Credential.AccessToken -} - -// UpdateCredential 更新凭证,并检查有效期 -func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { - res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) - if err != nil { - return err - } - - client.Credential = &Credential{AccessToken: res} - return nil -} diff --git a/pkg/filesystem/driver/googledrive/types.go b/pkg/filesystem/driver/googledrive/types.go deleted file mode 100644 index a459c155..00000000 --- a/pkg/filesystem/driver/googledrive/types.go +++ /dev/null @@ -1,43 +0,0 @@ -package googledrive - -import "encoding/gob" - -// RespError 接口返回错误 -type RespError struct { - APIError APIError `json:"error"` -} - -// APIError 接口返回的错误内容 -type APIError struct { - Code string `json:"code"` - Message string `json:"message"` -} - -// Error 实现error接口 -func (err RespError) Error() string { - return err.APIError.Message -} - -// Credential 获取token时返回的凭证 -type Credential struct { - ExpiresIn int64 `json:"expires_in"` - Scope string `json:"scope"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - UserID string `json:"user_id"` -} - -// OAuthError OAuth相关接口的错误响应 -type OAuthError struct { - ErrorType string `json:"error"` - ErrorDescription string `json:"error_description"` -} - -// Error 实现error接口 -func (err OAuthError) Error() string { - return err.ErrorDescription -} - -func init() { - gob.Register(Credential{}) -} diff --git a/pkg/filesystem/driver/handler.go b/pkg/filesystem/driver/handler.go deleted file mode 100644 index f2327813..00000000 --- a/pkg/filesystem/driver/handler.go +++ /dev/null @@ -1,51 +0,0 @@ -package driver - -import ( - "context" - "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -var ( - ErrorThumbNotExist = fmt.Errorf("thumb not exist") - ErrorThumbNotSupported = fmt.Errorf("thumb not supported") -) - -// Handler 存储策略适配器 -type Handler interface { - // 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭 - // 时,应取消上传并清理临时文件 - Put(ctx context.Context, file fsctx.FileHeader) error - - // 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误 - Delete(ctx context.Context, files []string) ([]string, error) - - // 获取文件内容 - Get(ctx context.Context, path string) (response.RSCloser, error) - - // 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指 - // 定为重定向 - // 如果缩略图不存在, 且需要 Cloudreve 代理生成并上传,应返回 ErrorThumbNotExist,生 - // 成的缩略图文件存储规则与本机策略一致。 - // 如果不支持此文件的缩略图,并且不希望后续继续请求此缩略图,应返回 ErrorThumbNotSupported - Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) - - // 获取外链/下载地址, - // 
url - 站点本身地址, - // isDownload - 是否直接下载 - Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) - - // Token 获取有效期为ttl的上传凭证和签名 - Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) - - // CancelToken 取消已经创建的有状态上传凭证 - CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error - - // List 递归列取远程端path路径下文件、目录,不包含path本身, - // 返回的对象路径以path作为起始根目录. - // recursive - 是否递归列出 - List(ctx context.Context, path string, recursive bool) ([]response.Object, error) -} diff --git a/pkg/filesystem/driver/local/handler.go b/pkg/filesystem/driver/local/handler.go deleted file mode 100644 index 85ba1af5..00000000 --- a/pkg/filesystem/driver/local/handler.go +++ /dev/null @@ -1,292 +0,0 @@ -package local - -import ( - "context" - "errors" - "fmt" - "io" - "net/url" - "os" - "path/filepath" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -const ( - Perm = 0744 -) - -// Driver 本地策略适配器 -type Driver struct { - Policy *model.Policy -} - -// List 递归列取给定物理路径下所有文件 -func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - var res []response.Object - - // 取得起始路径 - root := util.RelativePath(filepath.FromSlash(path)) - - // 开始遍历路径下的文件、目录 - err := filepath.Walk(root, - func(path string, info os.FileInfo, err error) error { - // 跳过根目录 - if path == root { - return nil - } - - if err != nil { - util.Log().Warning("Failed to walk folder %q: %s", path, err) - return filepath.SkipDir - } - - // 将遍历对象的绝对路径转换为相对路径 - rel, err := filepath.Rel(root, path) - if err != nil { - return err - } - - res = append(res, response.Object{ - Name: info.Name(), - RelativePath: filepath.ToSlash(rel), - Source: path, - Size: uint64(info.Size()), - IsDir: info.IsDir(), - LastModify: info.ModTime(), - }) - - // 如果非递归,则不步入目录 - if !recursive && info.IsDir() { - return filepath.SkipDir - } - - return nil - }) - - return res, err -} - -// Get 获取文件内容 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 打开文件 - file, err := os.Open(util.RelativePath(path)) - if err != nil { - util.Log().Debug("Failed to open file: %s", err) - return nil, err - } - - return file, nil -} - -// Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - fileInfo := file.Info() - dst := util.RelativePath(filepath.FromSlash(fileInfo.SavePath)) - - // 如果非 Overwrite,则检查是否有重名冲突 - if fileInfo.Mode&fsctx.Overwrite != fsctx.Overwrite { - if util.Exists(dst) { - util.Log().Warning("File with the same name existed or unavailable: %s", dst) - return errors.New("file with the same name existed or unavailable") - } - } - - // 如果目标目录不存在,创建 - basePath := filepath.Dir(dst) - if !util.Exists(basePath) { - err := os.MkdirAll(basePath, Perm) - if err != nil { - util.Log().Warning("Failed to create directory: %s", err) - return err - } - } - - var ( - out *os.File - err error - ) - - openMode := os.O_CREATE | os.O_RDWR - if fileInfo.Mode&fsctx.Append == fsctx.Append { - openMode 
|= os.O_APPEND - } else { - openMode |= os.O_TRUNC - } - - out, err = os.OpenFile(dst, openMode, Perm) - if err != nil { - util.Log().Warning("Failed to open or create file: %s", err) - return err - } - defer out.Close() - - if fileInfo.Mode&fsctx.Append == fsctx.Append { - stat, err := out.Stat() - if err != nil { - util.Log().Warning("Failed to read file info: %s", err) - return err - } - - if uint64(stat.Size()) < fileInfo.AppendStart { - return errors.New("size of unfinished uploaded chunks is not as expected") - } else if uint64(stat.Size()) > fileInfo.AppendStart { - out.Close() - if err := handler.Truncate(ctx, dst, fileInfo.AppendStart); err != nil { - return fmt.Errorf("failed to overwrite chunk: %w", err) - } - - out, err = os.OpenFile(dst, openMode, Perm) - defer out.Close() - if err != nil { - util.Log().Warning("Failed to create or open file: %s", err) - return err - } - } - } - - // 写入文件内容 - _, err = io.Copy(out, file) - return err -} - -func (handler Driver) Truncate(ctx context.Context, src string, size uint64) error { - util.Log().Warning("Truncate file %q to [%d].", src, size) - out, err := os.OpenFile(src, os.O_WRONLY, Perm) - if err != nil { - util.Log().Warning("Failed to open file: %s", err) - return err - } - - defer out.Close() - return out.Truncate(int64(size)) -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { - deleteFailed := make([]string, 0, len(files)) - var retErr error - - for _, value := range files { - filePath := util.RelativePath(filepath.FromSlash(value)) - if util.Exists(filePath) { - err := os.Remove(filePath) - if err != nil { - util.Log().Warning("Failed to delete file: %s", err) - retErr = err - deleteFailed = append(deleteFailed, value) - } - } - - // 尝试删除文件的缩略图(如果有) - _ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"))) - } - - return deleteFailed, retErr -} - -// Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // Quick check thumb existence on master. 
- if conf.SystemConfig.Mode == "master" && file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusNotExist { - // Tell invoker to generate a thumb - return nil, driver.ErrorThumbNotExist - } - - thumbFile, err := handler.Get(ctx, file.ThumbFile()) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - err = fmt.Errorf("thumb not exist: %w (%w)", err, driver.ErrorThumbNotExist) - } - - return nil, err - } - - return &response.ContentResponse{ - Redirect: false, - Content: thumbFile, - }, nil -} - -// Source 获取外链URL -func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return "", errors.New("failed to read file model context") - } - - var baseURL *url.URL - // 是否启用了CDN - if handler.Policy.BaseURL != "" { - cdnURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - baseURL = cdnURL - } - - var ( - signedURI *url.URL - err error - ) - if isDownload { - // 创建下载会话,将文件信息写入缓存 - downloadSessionID := util.RandStringRunes(16) - err = cache.Set("download_"+downloadSessionID, file, int(ttl)) - if err != nil { - return "", serializer.NewError(serializer.CodeCacheOperation, "Failed to create download session", err) - } - - // 签名生成文件记录 - signedURI, err = auth.SignURI( - auth.General, - fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID), - ttl, - ) - } else { - // 签名生成文件记录 - signedURI, err = auth.SignURI( - auth.General, - fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name), - ttl, - ) - } - - if err != nil { - return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err) - } - - finalURL := signedURI.String() - if baseURL != nil { - finalURL = baseURL.ResolveReference(signedURI).String() - } - - return finalURL, nil -} - -// Token 获取上传策略和认证Token,本地策略直接返回空值 -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - if util.Exists(uploadSession.SavePath) { - return nil, errors.New("placeholder file already exist") - } - - return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - }, nil -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil -} diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go deleted file mode 100644 index b73b5641..00000000 --- a/pkg/filesystem/driver/local/handler_test.go +++ /dev/null @@ -1,338 +0,0 @@ -package local - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "io" - "os" - "strings" - "testing" -) - -func TestHandler_Put(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - - defer func() { - os.Remove(util.RelativePath("TestHandler_Put.txt")) - os.Remove(util.RelativePath("inner/TestHandler_Put.txt")) - }() - - testCases := []struct { - file fsctx.FileHeader - errContains string - }{ - {&fsctx.FileStream{ - SavePath: "TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - 
}, ""}, - {&fsctx.FileStream{ - SavePath: "TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - }, "file with the same name existed or unavailable"}, - {&fsctx.FileStream{ - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - }, ""}, - {&fsctx.FileStream{ - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, ""}, - {&fsctx.FileStream{ - AppendStart: 10, - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, "size of unfinished uploaded chunks is not as expected"}, - {&fsctx.FileStream{ - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, ""}, - } - - for _, testCase := range testCases { - err := handler.Put(context.Background(), testCase.file) - if testCase.errContains != "" { - asserts.Error(err) - asserts.Contains(err.Error(), testCase.errContains) - } else { - asserts.NoError(err) - asserts.True(util.Exists(util.RelativePath(testCase.file.Info().SavePath))) - } - } -} - -func TestDriver_TruncateFailed(t *testing.T) { - a := assert.New(t) - h := Driver{} - a.Error(h.Truncate(context.Background(), "TestDriver_TruncateFailed", 0)) -} - -func TestHandler_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - filePath := util.RelativePath("TestHandler_Delete.file") - - file, err := os.Create(filePath) - asserts.NoError(err) - _ = file.Close() - list, err := handler.Delete(ctx, []string{"TestHandler_Delete.file"}) - asserts.Equal([]string{}, list) - asserts.NoError(err) - - file, err = os.Create(filePath) - _ = file.Close() - file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0)) - asserts.NoError(err) - list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file", "test.notexist"}) - file.Close() - asserts.Equal([]string{}, list) - asserts.NoError(err) - - list, err = handler.Delete(ctx, []string{"test.notexist"}) - asserts.Equal([]string{}, list) - asserts.NoError(err) - - file, err = os.Create(filePath) - asserts.NoError(err) - list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file"}) - _ = file.Close() - asserts.Equal([]string{}, list) - asserts.NoError(err) -} - -func TestHandler_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 成功 - file, err := os.Create(util.RelativePath("TestHandler_Get.txt")) - asserts.NoError(err) - _ = file.Close() - - rs, err := handler.Get(ctx, "TestHandler_Get.txt") - asserts.NoError(err) - asserts.NotNil(rs) - - // 文件不存在 - - rs, err = handler.Get(ctx, "TestHandler_Get_notExist.txt") - asserts.Error(err) - asserts.Nil(rs) -} - -func TestHandler_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - file, err := os.Create(util.RelativePath("TestHandler_Thumb._thumb")) - asserts.NoError(err) - file.Close() - - f := &model.File{ - SourceName: "TestHandler_Thumb", - MetadataSerialized: map[string]string{ - model.ThumbStatusMetadataKey: model.ThumbStatusExist, - }, - } - - // 正常 - { - thumb, err := handler.Thumb(ctx, f) - asserts.NoError(err) - asserts.NotNil(thumb.Content) - } - - // file 不存在 - { - f.SourceName = "not_exist" - _, err := handler.Thumb(ctx, f) - asserts.Error(err) - asserts.ErrorIs(err, driver.ErrorThumbNotExist) - } - - // thumb not exist - { - 
f.MetadataSerialized[model.ThumbStatusMetadataKey] = model.ThumbStatusNotExist - _, err := handler.Thumb(ctx, f) - asserts.Error(err) - asserts.ErrorIs(err, driver.ErrorThumbNotExist) - } -} - -func TestHandler_Source(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{}, - } - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 成功 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - } - - // 下载 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, true, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - asserts.Contains(sourceURL, "download") - } - - // 无法获取上下文 - { - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.Error(err) - asserts.Empty(sourceURL) - } - - // 设定了CDN - { - handler.Policy.BaseURL = "https://cqu.edu.cn" - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - asserts.Contains(sourceURL, "https://cqu.edu.cn") - } - - // 设定了CDN,解析失败 - { - handler.Policy.BaseURL = string([]byte{0x7f}) - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.Error(err) - asserts.Empty(sourceURL) - } -} - -func TestHandler_GetDownloadURL(t *testing.T) { - asserts := assert.New(t) - handler := Driver{Policy: &model.Policy{}} - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 成功 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - downloadURL, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(downloadURL, "sign=") - } - - // 无法获取上下文 - { - downloadURL, err := handler.Source(ctx, "", 10, true, 0) - asserts.Error(err) - asserts.Empty(downloadURL) - } -} - -func TestHandler_Token(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{}, - } - ctx := context.Background() - upSession := &serializer.UploadSession{SavePath: "TestHandler_Token"} - _, err := handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) - asserts.NoError(err) - - file, _ := os.Create("TestHandler_Token") - defer func() { - file.Close() - os.Remove("TestHandler_Token") - }() - - _, err = handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) - asserts.Error(err) - asserts.Contains(err.Error(), "already exist") -} - -func TestDriver_CancelToken(t *testing.T) { - a := assert.New(t) - handler := Driver{} - a.NoError(handler.CancelToken(context.Background(), &serializer.UploadSession{})) -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - - // 创建测试目录结构 - for _, path := range []string{ - "test/TestDriver_List/parent.txt", - "test/TestDriver_List/parent_folder2/sub2.txt", - 
"test/TestDriver_List/parent_folder1/sub_folder/sub1.txt", - "test/TestDriver_List/parent_folder1/sub_folder/sub2.txt", - } { - f, _ := util.CreatNestedFile(util.RelativePath(path)) - f.Close() - } - - // 非递归列出 - { - res, err := handler.List(ctx, "test/TestDriver_List", false) - asserts.NoError(err) - asserts.Len(res, 3) - } - - // 递归列出 - { - res, err := handler.List(ctx, "test/TestDriver_List", true) - asserts.NoError(err) - asserts.Len(res, 7) - } -} diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go deleted file mode 100644 index a675548b..00000000 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ /dev/null @@ -1,1155 +0,0 @@ -package onedrive - -import ( - "context" - "errors" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRequest(t *testing.T) { - asserts := assert.New(t) - client := Client{ - Policy: &model.Policy{}, - ClientID: "TestRequest", - Credential: &Credential{ - ExpiresIn: time.Now().Add(time.Duration(100) * time.Hour).Unix(), - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - }, - } - - // 请求发送失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - asserts.Equal("error", err.Error()) - } - - // 无法更新凭证 - { - client.Credential.RefreshToken = "" - client.Credential.AccessToken = "" - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - asserts.Error(err) - asserts.Empty(res) - client.Credential.RefreshToken = "RefreshToken" - client.Credential.AccessToken = "AccessToken" - } - - // 无法获取响应正文 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(mockReader("")), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // OneDrive返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - asserts.Equal("error msg", err.Error()) - } - - // OneDrive返回429错误 - { - header := http.Header{} - 
header.Add("retry-after", "120") - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 429, - Header: header, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - var retryErr *backoff.RetryableError - asserts.ErrorAs(err, &retryErr) - asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter) - } - - // OneDrive返回未知响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } -} - -func TestFileInfo_GetSourcePath(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - fileInfo := FileInfo{ - Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg", - ParentReference: parentReference{ - Path: "/drive/root:/123/32%201", - }, - } - asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath()) - } - - // 失败 - { - fileInfo := FileInfo{ - Name: "123.jpg", - ParentReference: parentReference{ - Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g", - }, - } - asserts.Equal("", fileInfo.GetSourcePath()) - } -} - -func TestClient_GetRequestURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - - // 出错 - { - client.Endpoints.EndpointURL = string([]byte{0x7f}) - asserts.Equal("", client.getRequestURL("123")) - } - - // 使用DriverResource - { - client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" - asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123")) - } - - // 不使用DriverResource - { - client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" - asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false))) - } -} - -func TestClient_GetSiteIDByURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - 
testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotEmpty(res) - asserts.Equal("123321", res) - } -} - -func TestClient_Meta(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.Meta(context.Background(), "", "123") - asserts.Error(err) - asserts.Nil(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "", "123") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "", "123") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } - - // 返回正常, 使用资源id - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "123321", "123") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } -} - -func TestClient_CreateUploadSession(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.CreateUploadSession(context.Background(), "123.jpg") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.CreateUploadSession(context.Background(), "123.jpg") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = 
time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.CreateUploadSession(context.Background(), "123.jpg", WithConflictBehavior("fail")) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res) - } -} - -func TestClient_GetUploadSessionStatus(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.UploadURL) - } -} - -func TestClient_UploadChunk(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cg := chunk.NewChunkGroup(&fsctx.FileStream{Size: 15}, 10, &backoff.ConstantBackoff{}, false) - - // 非最后分片,正常 - { - cg.Next() - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Equal("http://dev.com/2", res.UploadURL) - } - - // 非最后分片,异常响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - 
).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 最后分片,正常 - { - cg.Next() - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Nil(res) - } - - // 最后分片,失败 - { - cache.Set("setting_chunk_retries", "1", 0) - client.Credential.ExpiresIn = 0 - go func() { - time.Sleep(time.Duration(2) * time.Second) - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - }() - clientMock := ClientMock{} - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } -} - -func TestClient_Upload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - ctx := context.Background() - cache.Set("setting_chunk_retries", "1", 0) - cache.Set("setting_use_temp_chunk_buffer", "false", 0) - - // 小文件,简单上传,失败 - { - client.Credential.ExpiresIn = 0 - err := client.Upload(ctx, &fsctx.FileStream{ - Size: 5, - File: io.NopCloser(strings.NewReader("12345")), - }) - asserts.Error(err) - } - - // 无法创建分片会话 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - err := client.Upload(context.Background(), &fsctx.FileStream{ - Size: SmallFileSize + 1, - File: io.NopCloser(strings.NewReader("12345")), - }) - clientMock.AssertExpectations(t) - asserts.Error(err) - } - - // 分片上传失败 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - err := client.Upload(context.Background(), &fsctx.FileStream{ - Size: SmallFileSize + 1, - File: io.NopCloser(strings.NewReader("12345")), - }) - 
clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Contains(err.Error(), "failed to upload chunk") - } - -} - -func TestClient_SimpleUpload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cache.Set("setting_chunk_retries", "1", 0) - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } -} - -func TestClient_DeleteUploadSession(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - err := client.DeleteUploadSession(context.Background(), "123.jpg") - asserts.Error(err) - - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 204, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, - }) - client.Request = clientMock - err := client.DeleteUploadSession(context.Background(), "123.jpg") - clientMock.AssertExpectations(t) - asserts.NoError(err) - } -} - -func TestClient_BatchDelete(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 小于20个,失败1个 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), - }, - }) - client.Request = clientMock - res, err := client.BatchDelete(context.Background(), []string{"1", "2", "3", "1", "2"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Equal([]string{"2"}, res) - } -} - -func TestClient_Delete(t 
*testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - asserts.Error(err) - asserts.Len(res, 3) - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 3) - } - - // 成功2两个文件 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), - }, - }) - client.Request = clientMock - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Equal([]string{"2"}, res) - } -} - -func TestClient_ListChildren(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 根目录,请求失败,重测试 - { - client.Credential.ExpiresIn = 0 - res, err := client.ListChildren(context.Background(), "/") - asserts.Error(err) - asserts.Empty(res) - } - - // 非根目录,未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.ListChildren(context.Background(), "/uploads") - asserts.Error(err) - asserts.Empty(res) - } - - // 非根目录,成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)), - }, - }) - client.Request = clientMock - res, err := client.ListChildren(context.Background(), "/uploads") - asserts.NoError(err) - asserts.Len(res, 1) - } -} - -func TestClient_GetThumbURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.Error(err) - asserts.Empty(res) - } - - // 未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - 
StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.Error(err) - asserts.Empty(res) - } - - // 世纪互联 成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Endpoints.isInChina = true - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"url":"thumb"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.NoError(err) - asserts.Equal("thumb", res) - } - - // 非世纪互联 成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Endpoints.isInChina = false - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"large":{"url":"thumb"}}]}`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.NoError(err) - asserts.Equal("thumb", res) - } -} - -func TestClient_MonitorUpload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 客户端完成回调 - { - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - asserts.NotPanics(func() { - go func() { - time.Sleep(time.Duration(1) * time.Second) - mq.GlobalMQ.Publish("key", mq.Message{}) - }() - client.MonitorUpload("url", "key", "path", 10, 10) - }) - } - - // 上传会话到期,仍未完成上传,创建占位符 - { - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - asserts.NotPanics(func() { - client.MonitorUpload("url", "key", "path", 10, 0) - }) - } - - fmt.Println("测试:上传已完成,未发送回调") - // 上传已完成,未发送回调 - { - cache.Set("setting_onedrive_monitor_timeout", "0", 0) - cache.Set("setting_onedrive_callback_check", "0", 0) - - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), - }, - }) - client.Request = clientMock - cache.Set("callback_key3", "ok", 0) - - asserts.NotPanics(func() { - client.MonitorUpload("url", "key3", "path", 10, 10) - }) - - clientMock.AssertExpectations(t) - } - - fmt.Println("测试:上传仍未开始") - // 上传仍未开始 - { - cache.Set("setting_onedrive_monitor_timeout", "0", 0) - cache.Set("setting_onedrive_callback_check", "0", 0) - - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - 
client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"nextExpectedRanges":["0-"]}`)), - }, - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, - }) - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{}`)), - }, - }) - client.Request = clientMock - - asserts.NotPanics(func() { - client.MonitorUpload("url", "key4", "path", 10, 10) - }) - - clientMock.AssertExpectations(t) - } - -} diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go deleted file mode 100644 index 957af8e7..00000000 --- a/pkg/filesystem/driver/onedrive/client.go +++ /dev/null @@ -1,77 +0,0 @@ -package onedrive - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/request" -) - -var ( - // ErrAuthEndpoint 无法解析授权端点地址 - ErrAuthEndpoint = errors.New("failed to parse endpoint url") - // ErrInvalidRefreshToken 上传策略无有效的RefreshToken - ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy") - // ErrDeleteFile 无法删除文件 - ErrDeleteFile = errors.New("cannot delete file") - // ErrClientCanceled 客户端取消操作 - ErrClientCanceled = errors.New("client canceled") - // Desired thumb size not available - ErrThumbSizeNotFound = errors.New("thumb size not found") -) - -// Client OneDrive客户端 -type Client struct { - Endpoints *Endpoints - Policy *model.Policy - Credential *Credential - - ClientID string - ClientSecret string - Redirect string - - Request request.Client - ClusterController cluster.Controller -} - -// Endpoints OneDrive客户端相关设置 -type Endpoints struct { - OAuthURL string // OAuth认证的基URL - OAuthEndpoints *oauthEndpoint - EndpointURL string // 接口请求的基URL - isInChina bool // 是否为世纪互联 - DriverResource string // 要使用的驱动器 -} - -// NewClient 根据存储策略获取新的client -func NewClient(policy *model.Policy) (*Client, error) { - client := &Client{ - Endpoints: &Endpoints{ - OAuthURL: policy.BaseURL, - EndpointURL: policy.Server, - DriverResource: policy.OptionsSerialized.OdDriver, - }, - Credential: &Credential{ - RefreshToken: policy.AccessKey, - }, - Policy: policy, - ClientID: policy.BucketName, - ClientSecret: policy.SecretKey, - Redirect: policy.OptionsSerialized.OauthRedirect, - Request: request.NewClient(), - ClusterController: cluster.DefaultController, - } - - if client.Endpoints.DriverResource == "" { - client.Endpoints.DriverResource = "me/drive" - } - - oauthBase := client.getOAuthEndpoint() - if oauthBase == nil { - return nil, ErrAuthEndpoint - } - client.Endpoints.OAuthEndpoints = oauthBase - - return client, nil -} diff --git a/pkg/filesystem/driver/onedrive/client_test.go b/pkg/filesystem/driver/onedrive/client_test.go deleted file mode 100644 index aa3c1320..00000000 --- a/pkg/filesystem/driver/onedrive/client_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package onedrive - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - 
"github.com/stretchr/testify/assert" -) - -func TestNewClient(t *testing.T) { - asserts := assert.New(t) - // getOAuthEndpoint失败 - { - policy := model.Policy{ - BaseURL: string([]byte{0x7f}), - } - res, err := NewClient(&policy) - asserts.Error(err) - asserts.Nil(res) - } - - // 成功 - { - policy := model.Policy{} - res, err := NewClient(&policy) - asserts.NoError(err) - asserts.NotNil(res) - asserts.NotNil(res.Credential) - asserts.NotNil(res.Endpoints) - asserts.NotNil(res.Endpoints.OAuthEndpoints) - } -} diff --git a/pkg/filesystem/driver/onedrive/handler.go b/pkg/filesystem/driver/onedrive/handler.go deleted file mode 100644 index 149fdba8..00000000 --- a/pkg/filesystem/driver/onedrive/handler.go +++ /dev/null @@ -1,238 +0,0 @@ -package onedrive - -import ( - "context" - "errors" - "fmt" - "net/url" - "path" - "path/filepath" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Driver OneDrive 适配器 -type Driver struct { - Policy *model.Policy - Client *Client - HTTPClient request.Client -} - -// NewDriver 从存储策略初始化新的Driver实例 -func NewDriver(policy *model.Policy) (driver.Handler, error) { - client, err := NewClient(policy) - if policy.OptionsSerialized.ChunkSize == 0 { - policy.OptionsSerialized.ChunkSize = 50 << 20 // 50MB - } - - return Driver{ - Policy: policy, - Client: client, - HTTPClient: request.NewClient(), - }, err -} - -// List 列取项目 -func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - base = strings.TrimPrefix(base, "/") - // 列取子项目 - objects, _ := handler.Client.ListChildren(ctx, base) - - // 获取真实的列取起始根目录 - rootPath := base - if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok { - rootPath = realBase - } else { - ctx = context.WithValue(ctx, fsctx.PathCtx, base) - } - - // 整理结果 - res := make([]response.Object, 0, len(objects)) - for _, object := range objects { - source := path.Join(base, object.Name) - rel, err := filepath.Rel(rootPath, source) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: object.Name, - RelativePath: filepath.ToSlash(rel), - Source: source, - Size: object.Size, - IsDir: object.Folder != nil, - LastModify: time.Now(), - }) - } - - // 递归列取子目录 - if recursive { - for _, object := range objects { - if object.Folder != nil { - sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive) - res = append(res, sub...) 
- } - } - } - - return res, nil -} - -// Get 获取文件 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 获取文件源地址 - downloadURL, err := handler.Source( - ctx, - path, - 60, - false, - 0, - ) - if err != nil { - return nil, err - } - - // 获取文件数据流 - resp, err := handler.HTTPClient.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - return handler.Client.Upload(ctx, file) -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { - return handler.Client.BatchDelete(ctx, files) -} - -// Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - var ( - thumbSize = [2]uint{400, 300} - ok = false - ) - if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok { - return nil, errors.New("failed to get thumbnail size") - } - - res, err := handler.Client.GetThumbURL(ctx, file.SourceName, thumbSize[0], thumbSize[1]) - if err != nil { - var apiErr *RespError - if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) { - // OneDrive cannot generate thumbnail for this file - return nil, driver.ErrorThumbNotSupported - } - } - - return &response.ContentResponse{ - Redirect: true, - URL: res, - }, err -} - -// Source 获取外链URL -func (handler Driver) Source( - ctx context.Context, - path string, - ttl int64, - isDownload bool, - speed int, -) (string, error) { - cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path) - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID) - } - - // 尝试从缓存中查找 - if cachedURL, ok := cache.Get(cacheKey); ok { - return handler.replaceSourceHost(cachedURL.(string)) - } - - // 缓存不存在,重新获取 - res, err := handler.Client.Meta(ctx, "", path) - if err == nil { - // 写入新的缓存 - cache.Set( - cacheKey, - res.DownloadURL, - model.GetIntSetting("onedrive_source_timeout", 1800), - ) - return handler.replaceSourceHost(res.DownloadURL) - } - return "", err -} - -func (handler Driver) replaceSourceHost(origin string) (string, error) { - if handler.Policy.OptionsSerialized.OdProxy != "" { - source, err := url.Parse(origin) - if err != nil { - return "", err - } - - cdn, err := url.Parse(handler.Policy.OptionsSerialized.OdProxy) - if err != nil { - return "", err - } - - // 替换反代地址 - source.Scheme = cdn.Scheme - source.Host = cdn.Host - return source.String(), nil - } - - return origin, nil -} - -// Token 获取上传会话URL -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - fileInfo := file.Info() - - uploadURL, err := handler.Client.CreateUploadSession(ctx, fileInfo.SavePath, WithConflictBehavior("fail")) - if err != nil { - return nil, err - } - - // 监控回调及上传 - go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl) - - uploadSession.UploadURL = uploadURL - 
return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - UploadURLs: []string{uploadURL}, - }, nil -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return handler.Client.DeleteUploadSession(ctx, uploadSession.UploadURL) -} diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go deleted file mode 100644 index 2c9c2c26..00000000 --- a/pkg/filesystem/driver/onedrive/handler_test.go +++ /dev/null @@ -1,420 +0,0 @@ -package onedrive - -import ( - "context" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/jinzhu/gorm" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestDriver_Token(t *testing.T) { - asserts := assert.New(t) - h, _ := NewDriver(&model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }) - handler := h.(Driver) - - // 分片上传 失败 - { - cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - asserts.Error(err) - asserts.Nil(res) - } - - // 分片上传 成功 - { - cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - go func() { - time.Sleep(time.Duration(1) * time.Second) - mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{}) - }() - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{}) - asserts.NoError(err) - asserts.Equal("123321", res.UploadURLs[0]) - } -} - -func TestDriver_Source(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cache.Set("setting_onedrive_source_timeout", "1800", 0) - - // 失败 - { - res, 
err := handler.Source(context.Background(), "123.jpg", 1, true, 0) - asserts.Error(err) - asserts.Empty(res) - } - - // 命中缓存 成功 - { - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - cache.Set("onedrive_source_0_123.jpg", "res", 1) - res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0) - cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") - asserts.NoError(err) - asserts.Equal("res", res) - } - - // 命中缓存 上下文存在文件 成功 - { - file := model.File{} - file.ID = 1 - file.UpdatedAt = time.Now() - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0) - res, err := handler.Source(ctx, "123.jpg", 1, true, 0) - cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") - asserts.NoError(err) - asserts.Equal("res", res) - } - - // 成功 - { - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - handler.Client.Credential.AccessToken = "1" - res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0) - asserts.NoError(err) - asserts.Equal("123321", res) - } -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.AccessToken = "AccessToken" - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 非递归 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.List(context.Background(), "/", false) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 递归一次 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "me/drive/root/children?$top=999999999", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)), - }, - }) - clientMock.On( - "Request", - "GET", - "me/drive/root:/1:/children?$top=999999999", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.List(context.Background(), "/", true) - asserts.NoError(err) - asserts.Len(res, 2) - } -} - -func TestDriver_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: 
"test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - file := &model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}} - - // 失败 - { - ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20}) - res, err := handler.Thumb(ctx, file) - asserts.Error(err) - asserts.Empty(res.URL) - } - - // 上下文错误 - { - _, err := handler.Thumb(context.Background(), file) - asserts.Error(err) - } -} - -func TestDriver_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - _, err := handler.Delete(context.Background(), []string{"1"}) - asserts.Error(err) - } - -} - -func TestDriver_Put(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - err := handler.Put(context.Background(), &fsctx.FileStream{}) - asserts.Error(err) - } -} - -func TestDriver_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 无法获取source - { - res, err := handler.Get(context.Background(), "123.txt") - asserts.Error(err) - asserts.Nil(res) - } - - // 成功 - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - handler.Client.Credential.AccessToken = "1" - - driverClientMock := ClientMock{} - driverClientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`123`)), - }, - }) - handler.HTTPClient = driverClientMock - res, err := handler.Get(context.Background(), "123.txt") - clientMock.AssertExpectations(t) - asserts.NoError(err) - _, err = res.Seek(0, io.SeekEnd) - asserts.NoError(err) - content, err := ioutil.ReadAll(res) - asserts.NoError(err) - asserts.Equal("123", string(content)) -} - -func TestDriver_replaceSourceHost(t *testing.T) { - tests := []struct { - name string - origin string - cdn string - want string - wantErr bool - }{ - {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false}, - {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false}, - {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true}, - 
{"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - policy := &model.Policy{} - policy.OptionsSerialized.OdProxy = tt.cdn - handler := Driver{ - Policy: policy, - } - got, err := handler.replaceSourceHost(tt.origin) - if (err != nil) != tt.wantErr { - t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestDriver_CancelToken(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - err := handler.CancelToken(context.Background(), &serializer.UploadSession{}) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go deleted file mode 100644 index bb00005f..00000000 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ /dev/null @@ -1,192 +0,0 @@ -package onedrive - -import ( - "context" - "encoding/json" - "io/ioutil" - "net/http" - "net/url" - "strings" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// Error 实现error接口 -func (err OAuthError) Error() string { - return err.ErrorDescription -} - -// OAuthURL 获取OAuth认证页面URL -func (client *Client) OAuthURL(ctx context.Context, scope []string) string { - query := url.Values{ - "client_id": {client.ClientID}, - "scope": {strings.Join(scope, " ")}, - "response_type": {"code"}, - "redirect_uri": {client.Redirect}, - } - client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode() - return client.Endpoints.OAuthEndpoints.authorize.String() -} - -// getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址 -func (client *Client) getOAuthEndpoint() *oauthEndpoint { - base, err := url.Parse(client.Endpoints.OAuthURL) - if err != nil { - return nil - } - var ( - token *url.URL - authorize *url.URL - ) - switch base.Host { - case "login.live.com": - token, _ = url.Parse("https://login.live.com/oauth20_token.srf") - authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf") - case "login.chinacloudapi.cn": - client.Endpoints.isInChina = true - token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token") - authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize") - default: - token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token") - authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize") - } - - return &oauthEndpoint{ - token: *token, - authorize: *authorize, - } -} - -// ObtainToken 通过code或refresh_token兑换token -func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) { - options := newDefaultOption() - for _, o := range opts { - o.apply(options) - } - - body := url.Values{ - "client_id": {client.ClientID}, - "redirect_uri": {client.Redirect}, - "client_secret": {client.ClientSecret}, - } - if options.code != "" { - body.Add("grant_type", "authorization_code") - body.Add("code", options.code) - } else { - body.Add("grant_type", "refresh_token") - 
body.Add("refresh_token", options.refreshToken) - } - strBody := body.Encode() - - res := client.Request.Request( - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - ioutil.NopCloser(strings.NewReader(strBody)), - request.WithHeader(http.Header{ - "Content-Type": {"application/x-www-form-urlencoded"}}, - ), - request.WithContentLength(int64(len(strBody))), - ) - if res.Err != nil { - return nil, res.Err - } - - respBody, err := res.GetResponse() - if err != nil { - return nil, err - } - - var ( - errResp OAuthError - credential Credential - decodeErr error - ) - - if res.Response.StatusCode != 200 { - decodeErr = json.Unmarshal([]byte(respBody), &errResp) - } else { - decodeErr = json.Unmarshal([]byte(respBody), &credential) - } - if decodeErr != nil { - return nil, decodeErr - } - - if errResp.ErrorType != "" { - return nil, errResp - } - - return &credential, nil - -} - -// UpdateCredential 更新凭证,并检查有效期 -func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error { - if isSlave { - return client.fetchCredentialFromMaster(ctx) - } - - oauth.GlobalMutex.Lock(client.Policy.ID) - defer oauth.GlobalMutex.Unlock(client.Policy.ID) - - // 如果已存在凭证 - if client.Credential != nil && client.Credential.AccessToken != "" { - // 检查已有凭证是否过期 - if client.Credential.ExpiresIn > time.Now().Unix() { - // 未过期,不要更新 - return nil - } - } - - // 尝试从缓存中获取凭证 - if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok { - credential := cacheCredential.(Credential) - if credential.ExpiresIn > time.Now().Unix() { - client.Credential = &credential - return nil - } - } - - // 获取新的凭证 - if client.Credential == nil || client.Credential.RefreshToken == "" { - // 无有效的RefreshToken - util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name) - return ErrInvalidRefreshToken - } - - credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken)) - if err != nil { - return err - } - - // 更新有效期为绝对时间戳 - expires := credential.ExpiresIn - 60 - credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix() - client.Credential = credential - - // 更新存储策略的 RefreshToken - client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken) - - // 更新缓存 - cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) - - return nil -} - -func (client *Client) AccessToken() string { - return client.Credential.AccessToken -} - -// UpdateCredential 更新凭证,并检查有效期 -func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { - res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) - if err != nil { - return err - } - - client.Credential = &Credential{AccessToken: res} - return nil -} diff --git a/pkg/filesystem/driver/onedrive/oauth_test.go b/pkg/filesystem/driver/onedrive/oauth_test.go deleted file mode 100644 index b2525b7c..00000000 --- a/pkg/filesystem/driver/onedrive/oauth_test.go +++ /dev/null @@ -1,386 +0,0 @@ -package onedrive - -import ( - "context" - "database/sql" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" - "io" - "io/ioutil" - "net/http" - "net/url" - "strings" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -var mock 
sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestGetOAuthEndpoint(t *testing.T) { - asserts := assert.New(t) - - // URL解析失败 - { - client := Client{ - Endpoints: &Endpoints{ - OAuthURL: string([]byte{0x7f}), - }, - } - res := client.getOAuthEndpoint() - asserts.Nil(res) - } - - { - testCase := []struct { - OAuthURL string - token string - auth string - isChina bool - }{ - { - OAuthURL: "http://login.live.com", - token: "https://login.live.com/oauth20_token.srf", - auth: "https://login.live.com/oauth20_authorize.srf", - isChina: false, - }, - { - OAuthURL: "http://login.chinacloudapi.cn", - token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token", - auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize", - isChina: true, - }, - { - OAuthURL: "other", - token: "https://login.microsoftonline.com/common/oauth2/v2.0/token", - auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", - isChina: false, - }, - } - - for i, testCase := range testCase { - client := Client{ - Endpoints: &Endpoints{ - OAuthURL: testCase.OAuthURL, - }, - } - res := client.getOAuthEndpoint() - asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i) - asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i) - asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i) - } - } -} - -func TestClient_OAuthURL(t *testing.T) { - asserts := assert.New(t) - - client := Client{ - ClientID: "client_id", - Redirect: "http://cloudreve.org/callback", - Endpoints: &Endpoints{}, - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"})) - asserts.NoError(err) - query := res.Query() - asserts.Equal("client_id", query.Get("client_id")) - asserts.Equal("scope1 scope2", query.Get("scope")) - asserts.Equal(client.Redirect, query.Get("redirect_uri")) - -} - -type ClientMock struct { - testMock.Mock -} - -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - args := m.Called(method, target, body, opts) - return args.Get(0).(*request.Response) -} - -type mockReader string - -func (r mockReader) Read(b []byte) (int, error) { - return 0, errors.New("read error") -} - -func TestClient_ObtainToken(t *testing.T) { - asserts := assert.New(t) - - client := Client{ - Endpoints: &Endpoints{}, - ClientID: "ClientID", - ClientSecret: "ClientSecret", - Redirect: "Redirect", - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - - // 刷新Token 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("i am token", res.AccessToken) - } - - // 重新获取 无法发送请求 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - 
testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background(), WithCode("code")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 刷新Token 无法获取响应正文 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(mockReader("")), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - asserts.Equal("read error", err.Error()) - } - - // 刷新Token OneDrive返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - asserts.Equal("", err.Error()) - } - - // 刷新Token OneDrive未知响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } -} - -func TestClient_UpdateCredential(t *testing.T) { - asserts := assert.New(t) - client := Client{ - Policy: &model.Policy{Model: gorm.Model{ID: 257}}, - Endpoints: &Endpoints{}, - ClientID: "TestClient_UpdateCredential", - ClientSecret: "ClientSecret", - Redirect: "Redirect", - Credential: &Credential{}, - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - - // 无有效的RefreshToken - { - err := client.UpdateCredential(context.Background(), false) - asserts.Equal(ErrInvalidRefreshToken, err) - client.Credential = nil - err = client.UpdateCredential(context.Background(), false) - asserts.Equal(ErrInvalidRefreshToken, err) - } - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)), - }, - }) - client.Request = clientMock - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := client.UpdateCredential(context.Background(), false) - clientMock.AssertExpectations(t) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential") - asserts.True(ok) - cacheCredential := cacheRes.(Credential) - asserts.Equal("new_refresh_token", cacheCredential.RefreshToken) - 
asserts.Equal("i am token", cacheCredential.AccessToken) - } - - // OneDrive返回错误 - { - cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_") - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)), - }, - }) - client.Request = clientMock - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - err := client.UpdateCredential(context.Background(), false) - clientMock.AssertExpectations(t) - asserts.Error(err) - } - - // 从缓存中获取 - { - cache.Set("onedrive_TestClient_UpdateCredential", Credential{ - ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - }, 0) - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - err := client.UpdateCredential(context.Background(), false) - asserts.NoError(err) - asserts.Equal("AccessToken", client.Credential.AccessToken) - asserts.Equal("RefreshToken", client.Credential.RefreshToken) - } - - // 无需重新获取 - { - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - AccessToken: "AccessToken2", - ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), - } - err := client.UpdateCredential(context.Background(), false) - asserts.NoError(err) - asserts.Equal("AccessToken2", client.Credential.AccessToken) - } - - // slave failed - { - mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error")) - client.ClusterController = mockController - err := client.UpdateCredential(context.Background(), true) - asserts.Error(err) - } - - // slave success - { - mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil) - client.ClusterController = mockController - err := client.UpdateCredential(context.Background(), true) - asserts.NoError(err) - asserts.Equal("AccessToken3", client.Credential.AccessToken) - } -} diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go deleted file mode 100644 index 2ae50a37..00000000 --- a/pkg/filesystem/driver/oss/handler.go +++ /dev/null @@ -1,491 +0,0 @@ -package oss - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/url" - "path" - "path/filepath" - "strings" - "time" - - "github.com/HFO4/aliyun-oss-go-sdk/oss" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// UploadPolicy 阿里云OSS上传策略 -type UploadPolicy struct { - Expiration string `json:"expiration"` - Conditions []interface{} `json:"conditions"` -} - -// CallbackPolicy 回调策略 -type CallbackPolicy struct { - CallbackURL string `json:"callbackUrl"` - CallbackBody string `json:"callbackBody"` - CallbackBodyType string `json:"callbackBodyType"` 
-} - -// Driver 阿里云OSS策略适配器 -type Driver struct { - Policy *model.Policy - client *oss.Client - bucket *oss.Bucket - HTTPClient request.Client -} - -type key int - -const ( - chunkRetrySleep = time.Duration(5) * time.Second - - // MultiPartUploadThreshold 服务端使用分片上传的阈值 - MultiPartUploadThreshold uint64 = 5 * (1 << 30) // 5GB - // VersionID 文件版本标识 - VersionID key = iota -) - -func NewDriver(policy *model.Policy) (*Driver, error) { - if policy.OptionsSerialized.ChunkSize == 0 { - policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB - } - - driver := &Driver{ - Policy: policy, - HTTPClient: request.NewClient(), - } - - return driver, driver.InitOSSClient(false) -} - -// CORS 创建跨域策略 -func (handler *Driver) CORS() error { - return handler.client.SetBucketCORS(handler.Policy.BucketName, []oss.CORSRule{ - { - AllowedOrigin: []string{"*"}, - AllowedMethod: []string{ - "GET", - "POST", - "PUT", - "DELETE", - "HEAD", - }, - ExposeHeader: []string{}, - AllowedHeader: []string{"*"}, - MaxAgeSeconds: 3600, - }, - }) -} - -// InitOSSClient 初始化OSS鉴权客户端 -func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error { - if handler.Policy == nil { - return errors.New("empty policy") - } - - // 决定是否使用内网 Endpoint - endpoint := handler.Policy.Server - if handler.Policy.OptionsSerialized.ServerSideEndpoint != "" && !forceUsePublicEndpoint { - endpoint = handler.Policy.OptionsSerialized.ServerSideEndpoint - } - - // 初始化客户端 - client, err := oss.New(endpoint, handler.Policy.AccessKey, handler.Policy.SecretKey) - if err != nil { - return err - } - handler.client = client - - // 初始化存储桶 - bucket, err := client.Bucket(handler.Policy.BucketName) - if err != nil { - return err - } - handler.bucket = bucket - - return nil -} - -// List 列出OSS上的文件 -func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - // 列取文件 - base = strings.TrimPrefix(base, "/") - if base != "" { - base += "/" - } - - var ( - delimiter string - marker string - objects []oss.ObjectProperties - commons []string - ) - if !recursive { - delimiter = "/" - } - - for { - subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base), - oss.MaxKeys(1000), oss.Delimiter(delimiter)) - if err != nil { - return nil, err - } - objects = append(objects, subRes.Objects...) - commons = append(commons, subRes.CommonPrefixes...) 
- marker = subRes.NextMarker - if marker == "" { - break - } - } - - // 处理列取结果 - res := make([]response.Object, 0, len(objects)+len(commons)) - // 处理目录 - for _, object := range commons { - rel, err := filepath.Rel(base, object) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object), - RelativePath: filepath.ToSlash(rel), - Size: 0, - IsDir: true, - LastModify: time.Now(), - }) - } - // 处理文件 - for _, object := range objects { - rel, err := filepath.Rel(base, object.Key) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object.Key), - Source: object.Key, - RelativePath: filepath.ToSlash(rel), - Size: uint64(object.Size), - IsDir: false, - LastModify: object.LastModified, - }) - } - - return res, nil -} - -// Get 获取文件 -func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 通过VersionID禁止缓存 - ctx = context.WithValue(ctx, VersionID, time.Now().UnixNano()) - - // 尽可能使用私有 Endpoint - ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false) - - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - // 获取文件数据流 - resp, err := handler.HTTPClient.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - fileInfo := file.Info() - - // 凭证有效期 - credentialTTL := model.GetIntSetting("upload_session_timeout", 3600) - - // 是否允许覆盖 - overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite - options := []oss.Option{ - oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)), - oss.ForbidOverWrite(!overwrite), - } - - // 小文件直接上传 - if fileInfo.Size < MultiPartUploadThreshold { - return handler.bucket.PutObject(fileInfo.SavePath, file, options...) - } - - // 超过阈值时使用分片上传 - imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...) 
- if err != nil { - return fmt.Errorf("failed to initiate multipart upload: %w", err) - } - - chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{ - Max: model.GetIntSetting("chunk_retries", 5), - Sleep: chunkRetrySleep, - }, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer"))) - - uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { - _, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1) - return err - } - - for chunks.Next() { - if err := chunks.Process(uploadFunc); err != nil { - return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) - } - } - - _, err = handler.bucket.CompleteMultipartUpload(imur, oss.CompleteAll("yes"), oss.ForbidOverWrite(!overwrite)) - return err -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件 -func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - // 删除文件 - delRes, err := handler.bucket.DeleteObjects(files) - - if err != nil { - return files, err - } - - // 统计未删除的文件 - failed := util.SliceDifference(files, delRes.DeletedObjects) - if len(failed) > 0 { - return failed, errors.New("failed to delete") - } - - return []string{}, nil -} - -// Thumb 获取文件缩略图 -func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // quick check by extension name - // https://help.aliyun.com/document_detail/183902.html - supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"} - if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 { - supported = handler.Policy.OptionsSerialized.ThumbExts - } - - if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) { - return nil, driver.ErrorThumbNotSupported - } - - // 初始化客户端 - if err := handler.InitOSSClient(true); err != nil { - return nil, err - } - - var ( - thumbSize = [2]uint{400, 300} - ok = false - ) - if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok { - return nil, errors.New("failed to get thumbnail size") - } - - thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85) - - thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d/quality,q_%d", thumbSize[1], thumbSize[0], thumbEncodeQuality) - ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, thumbParam) - thumbOption := []oss.Option{oss.Process(thumbParam)} - thumbURL, err := handler.signSourceURL( - ctx, - file.SourceName, - int64(model.GetIntSetting("preview_timeout", 60)), - thumbOption, - ) - if err != nil { - return nil, err - } - - return &response.ContentResponse{ - Redirect: true, - URL: thumbURL, - }, nil -} - -// Source 获取外链URL -func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - // 初始化客户端 - usePublicEndpoint := true - if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok { - usePublicEndpoint = forceUsePublicEndpoint - } - if err := handler.InitOSSClient(usePublicEndpoint); err != nil { - return "", err - } - - // 尝试从上下文获取文件名 - fileName := "" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - // 添加各项设置 - var signOptions = make([]oss.Option, 0, 2) - if isDownload { - signOptions = append(signOptions, oss.ResponseContentDisposition("attachment; filename=\""+url.PathEscape(fileName)+"\"")) - } - if speed > 0 { - // Byte 转换为 bit - speed *= 8 - - // OSS对速度值有范围限制 - if speed < 819200 { - speed = 819200 - } - if speed > 838860800 { - speed = 
838860800 - } - signOptions = append(signOptions, oss.TrafficLimitParam(int64(speed))) - } - - return handler.signSourceURL(ctx, path, ttl, signOptions) -} - -func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64, options []oss.Option) (string, error) { - signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...) - if err != nil { - return "", err - } - - // 将最终生成的签名URL域名换成用户自定义的加速域名(如果有) - finalURL, err := url.Parse(signedURL) - if err != nil { - return "", err - } - - // 公有空间替换掉Key及不支持的头 - if !handler.Policy.IsPrivate { - query := finalURL.Query() - query.Del("OSSAccessKeyId") - query.Del("Signature") - query.Del("response-content-disposition") - query.Del("x-oss-traffic-limit") - finalURL.RawQuery = query.Encode() - } - - if handler.Policy.BaseURL != "" { - cdnURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - finalURL.Host = cdnURL.Host - finalURL.Scheme = cdnURL.Scheme - } - - return finalURL.String(), nil -} - -// Token 获取上传策略和认证Token -func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 初始化客户端 - if err := handler.InitOSSClient(true); err != nil { - return nil, err - } - - // 生成回调地址 - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/oss/" + uploadSession.Key) - apiURL := siteURL.ResolveReference(apiBaseURI) - - // 回调策略 - callbackPolicy := CallbackPolicy{ - CallbackURL: apiURL.String(), - CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`, - CallbackBodyType: "application/json", - } - callbackPolicyJSON, err := json.Marshal(callbackPolicy) - if err != nil { - return nil, fmt.Errorf("failed to encode callback policy: %w", err) - } - callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON) - - // 初始化分片上传 - fileInfo := file.Info() - options := []oss.Option{ - oss.Expires(time.Now().Add(time.Duration(ttl) * time.Second)), - oss.ForbidOverWrite(true), - oss.ContentType(fileInfo.DetectMimeType()), - } - imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...) 
- if err != nil { - return nil, fmt.Errorf("failed to initialize multipart upload: %w", err) - } - uploadSession.UploadID = imur.UploadID - - // 为每个分片签名上传 URL - chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false) - urls := make([]string, chunks.Num()) - for chunks.Next() { - err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { - signedURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPut, ttl, - oss.PartNumber(c.Index()+1), - oss.UploadID(imur.UploadID), - oss.ContentType("application/octet-stream")) - if err != nil { - return err - } - - urls[c.Index()] = signedURL - return nil - }) - if err != nil { - return nil, err - } - } - - // 签名完成分片上传的URL - completeURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPost, ttl, - oss.ContentType("application/octet-stream"), - oss.UploadID(imur.UploadID), - oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)), - oss.CompleteAll("yes"), - oss.ForbidOverWrite(true), - oss.CallbackParam(callbackPolicyEncoded)) - if err != nil { - return nil, err - } - - return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - UploadID: imur.UploadID, - UploadURLs: urls, - CompleteURL: completeURL, - }, nil -} - -// 取消上传凭证 -func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.SavePath}, nil) -} diff --git a/pkg/filesystem/driver/qiniu/handler.go b/pkg/filesystem/driver/qiniu/handler.go deleted file mode 100644 index a11b5740..00000000 --- a/pkg/filesystem/driver/qiniu/handler.go +++ /dev/null @@ -1,354 +0,0 @@ -package qiniu - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "net/http" - "net/url" - "path" - "path/filepath" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/qiniu/go-sdk/v7/auth/qbox" - "github.com/qiniu/go-sdk/v7/storage" -) - -// Driver 本地策略适配器 -type Driver struct { - Policy *model.Policy - mac *qbox.Mac - cfg *storage.Config - bucket *storage.BucketManager -} - -func NewDriver(policy *model.Policy) *Driver { - if policy.OptionsSerialized.ChunkSize == 0 { - policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB - } - - mac := qbox.NewMac(policy.AccessKey, policy.SecretKey) - cfg := &storage.Config{UseHTTPS: true} - return &Driver{ - Policy: policy, - mac: mac, - cfg: cfg, - bucket: storage.NewBucketManager(mac, cfg), - } -} - -// List 列出给定路径下的文件 -func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - base = strings.TrimPrefix(base, "/") - if base != "" { - base += "/" - } - - var ( - delimiter string - marker string - objects []storage.ListItem - commons []string - ) - if !recursive { - delimiter = "/" - } - - for { - entries, folders, nextMarker, hashNext, err := handler.bucket.ListFiles( - handler.Policy.BucketName, - base, delimiter, marker, 1000) - if err != nil { - return nil, err - } - objects = append(objects, entries...) 
- commons = append(commons, folders...) - if !hashNext { - break - } - marker = nextMarker - } - - // 处理列取结果 - res := make([]response.Object, 0, len(objects)+len(commons)) - // 处理目录 - for _, object := range commons { - rel, err := filepath.Rel(base, object) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object), - RelativePath: filepath.ToSlash(rel), - Size: 0, - IsDir: true, - LastModify: time.Now(), - }) - } - // 处理文件 - for _, object := range objects { - rel, err := filepath.Rel(base, object.Key) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(object.Key), - Source: object.Key, - RelativePath: filepath.ToSlash(rel), - Size: uint64(object.Fsize), - IsDir: false, - LastModify: time.Unix(object.PutTime/10000000, 0), - }) - } - - return res, nil -} - -// Get 获取文件 -func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 给文件名加上随机参数以强制拉取 - path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano()) - - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - // 获取文件数据流 - client := request.NewClient() - resp, err := client.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithHeader( - http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}}, - ), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - // 凭证有效期 - credentialTTL := model.GetIntSetting("upload_session_timeout", 3600) - - // 生成上传策略 - fileInfo := file.Info() - scope := handler.Policy.BucketName - if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite { - scope = fmt.Sprintf("%s:%s", handler.Policy.BucketName, fileInfo.SavePath) - } - - putPolicy := storage.PutPolicy{ - // 指定为覆盖策略 - Scope: scope, - SaveKey: fileInfo.SavePath, - ForceSaveKey: true, - FsizeLimit: int64(fileInfo.Size), - } - // 是否开启了MIMEType限制 - if handler.Policy.OptionsSerialized.MimeType != "" { - putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType - } - - // 生成上传凭证 - token, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, int64(credentialTTL), false) - if err != nil { - return err - } - - // 创建上传表单 - cfg := storage.Config{} - formUploader := storage.NewFormUploader(&cfg) - ret := storage.PutRet{} - putExtra := storage.PutExtra{ - Params: map[string]string{}, - } - - // 开始上传 - err = formUploader.Put(ctx, &ret, token.Credential, fileInfo.SavePath, file, int64(fileInfo.Size), &putExtra) - if err != nil { - return err - } - - return nil -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件 -func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - // TODO 大于一千个文件需要分批发送 - deleteOps := make([]string, 0, len(files)) - for _, key := range files { - deleteOps = append(deleteOps, storage.URIDelete(handler.Policy.BucketName, key)) - } - - rets, err := handler.bucket.Batch(deleteOps) - - // 处理删除结果 - if err != nil { - failed := make([]string, 0, len(rets)) - for k, ret := range rets { - if ret.Code != 200 && ret.Code != 612 { - failed = append(failed, files[k]) - } - } - return failed, errors.New("删除失败") - } 
- - return []string{}, nil -} - -// Thumb 获取文件缩略图 -func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // quick check by extension name - // https://developer.qiniu.com/dora/api/basic-processing-images-imageview2 - supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"} - if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 { - supported = handler.Policy.OptionsSerialized.ThumbExts - } - - if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) { - return nil, driver.ErrorThumbNotSupported - } - - var ( - thumbSize = [2]uint{400, 300} - ok = false - ) - if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok { - return nil, errors.New("failed to get thumbnail size") - } - - thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85) - - thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d/q/%d", file.SourceName, thumbSize[0], thumbSize[1], thumbEncodeQuality) - return &response.ContentResponse{ - Redirect: true, - URL: handler.signSourceURL( - ctx, - thumb, - int64(model.GetIntSetting("preview_timeout", 60)), - ), - }, nil -} - -// Source 获取外链URL -func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - // 尝试从上下文获取文件名 - fileName := "" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - // 加入下载相关设置 - if isDownload { - path = path + "?attname=" + url.PathEscape(fileName) - } - - // 取得原始文件地址 - return handler.signSourceURL(ctx, path, ttl), nil -} - -func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64) string { - var sourceURL string - if handler.Policy.IsPrivate { - deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix() - sourceURL = storage.MakePrivateURL(handler.mac, handler.Policy.BaseURL, path, deadline) - } else { - sourceURL = storage.MakePublicURL(handler.Policy.BaseURL, path) - } - return sourceURL -} - -// Token 获取上传策略和认证Token -func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 生成回调地址 - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/qiniu/" + uploadSession.Key) - apiURL := siteURL.ResolveReference(apiBaseURI) - - // 创建上传策略 - fileInfo := file.Info() - putPolicy := storage.PutPolicy{ - Scope: handler.Policy.BucketName, - CallbackURL: apiURL.String(), - CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`, - CallbackBodyType: "application/json", - SaveKey: fileInfo.SavePath, - ForceSaveKey: true, - FsizeLimit: int64(handler.Policy.MaxSize), - } - // 是否开启了MIMEType限制 - if handler.Policy.OptionsSerialized.MimeType != "" { - putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType - } - - credential, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, ttl, true) - if err != nil { - return nil, fmt.Errorf("failed to init parts: %w", err) - } - - credential.SessionID = uploadSession.Key - credential.ChunkSize = handler.Policy.OptionsSerialized.ChunkSize - - uploadSession.UploadURL = credential.UploadURLs[0] - uploadSession.Credential = credential.Credential - - return credential, nil -} - -// getUploadCredential 签名上传策略并创建上传会话 -func (handler *Driver) getUploadCredential(ctx context.Context, policy storage.PutPolicy, file *fsctx.UploadTaskInfo, TTL int64, resume bool) (*serializer.UploadCredential, error) { - // 上传凭证 - 
policy.Expires = uint64(TTL) - upToken := policy.UploadToken(handler.mac) - - // 初始化分片上传 - resumeUploader := storage.NewResumeUploaderV2(handler.cfg) - upHost, err := resumeUploader.UpHost(handler.Policy.AccessKey, handler.Policy.BucketName) - if err != nil { - return nil, err - } - - ret := &storage.InitPartsRet{} - if resume { - err = resumeUploader.InitParts(ctx, upToken, upHost, handler.Policy.BucketName, file.SavePath, true, ret) - } - - return &serializer.UploadCredential{ - UploadURLs: []string{upHost + "/buckets/" + handler.Policy.BucketName + "/objects/" + base64.URLEncoding.EncodeToString([]byte(file.SavePath)) + "/uploads/" + ret.UploadID}, - Credential: upToken, - }, err -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - resumeUploader := storage.NewResumeUploaderV2(handler.cfg) - return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0) -} diff --git a/pkg/filesystem/driver/remote/client.go b/pkg/filesystem/driver/remote/client.go deleted file mode 100644 index b1b1804d..00000000 --- a/pkg/filesystem/driver/remote/client.go +++ /dev/null @@ -1,195 +0,0 @@ -package remote - -import ( - "context" - "encoding/json" - "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" - "io" - "net/http" - "net/url" - "path" - "strings" - "time" -) - -const ( - basePath = "/api/v3/slave/" - OverwriteHeader = auth.CrHeaderPrefix + "Overwrite" - chunkRetrySleep = time.Duration(5) * time.Second -) - -// Client to operate uploading to remote slave server -type Client interface { - // CreateUploadSession creates remote upload session - CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error - // GetUploadURL signs an url for uploading file - GetUploadURL(ttl int64, sessionID string) (string, string, error) - // Upload uploads file to remote server - Upload(ctx context.Context, file fsctx.FileHeader) error - // DeleteUploadSession deletes remote upload session - DeleteUploadSession(ctx context.Context, sessionID string) error -} - -// NewClient creates new Client from given policy -func NewClient(policy *model.Policy) (Client, error) { - authInstance := auth.HMACAuth{[]byte(policy.SecretKey)} - serverURL, err := url.Parse(policy.Server) - if err != nil { - return nil, err - } - - base, _ := url.Parse(basePath) - signTTL := model.GetIntSetting("slave_api_timeout", 60) - - return &remoteClient{ - policy: policy, - authInstance: authInstance, - httpClient: request.NewClient( - request.WithEndpoint(serverURL.ResolveReference(base).String()), - request.WithCredential(authInstance, int64(signTTL)), - request.WithMasterMeta(), - request.WithSlaveMeta(policy.AccessKey), - ), - }, nil -} - -type remoteClient struct { - policy *model.Policy - authInstance auth.Auth - httpClient request.Client -} - -func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error { - ttl := model.GetIntSetting("upload_session_timeout", 86400) - fileInfo := file.Info() - session := 
&serializer.UploadSession{ - Key: uuid.Must(uuid.NewV4()).String(), - VirtualPath: fileInfo.VirtualPath, - Name: fileInfo.FileName, - Size: fileInfo.Size, - SavePath: fileInfo.SavePath, - LastModified: fileInfo.LastModified, - Policy: *c.policy, - } - - // Create upload session - overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite - if err := c.CreateUploadSession(ctx, session, int64(ttl), overwrite); err != nil { - return fmt.Errorf("failed to create upload session: %w", err) - } - - // Initial chunk groups - chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{ - Max: model.GetIntSetting("chunk_retries", 5), - Sleep: chunkRetrySleep, - }, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer"))) - - uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error { - return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length()) - } - - // upload chunks - for chunks.Next() { - if err := chunks.Process(uploadFunc); err != nil { - if err := c.DeleteUploadSession(ctx, session.Key); err != nil { - util.Log().Warning("failed to delete upload session: %s", err) - } - - return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err) - } - } - - return nil -} - -func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error { - resp, err := c.httpClient.Request( - "DELETE", - "upload/"+sessionID, - nil, - request.WithContext(ctx), - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return err - } - - if resp.Code != 0 { - return serializer.NewErrorFromResponse(resp) - } - - return nil -} - -func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error { - reqBodyEncoded, err := json.Marshal(map[string]interface{}{ - "session": session, - "ttl": ttl, - "overwrite": overwrite, - }) - if err != nil { - return err - } - - bodyReader := strings.NewReader(string(reqBodyEncoded)) - resp, err := c.httpClient.Request( - "PUT", - "upload", - bodyReader, - request.WithContext(ctx), - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return err - } - - if resp.Code != 0 { - return serializer.NewErrorFromResponse(resp) - } - - return nil -} - -func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string, error) { - base, err := url.Parse(c.policy.Server) - if err != nil { - return "", "", err - } - - base.Path = path.Join(base.Path, basePath, "upload", sessionID) - req, err := http.NewRequest("POST", base.String(), nil) - if err != nil { - return "", "", err - } - - req = auth.SignRequest(c.authInstance, req, ttl) - return req.URL.String(), req.Header["Authorization"][0], nil -} - -func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error { - resp, err := c.httpClient.Request( - "POST", - fmt.Sprintf("upload/%s?chunk=%d", sessionID, index), - chunk, - request.WithContext(ctx), - request.WithTimeout(time.Duration(0)), - request.WithContentLength(size), - request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}), - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return err - } - - if resp.Code != 0 { - return serializer.NewErrorFromResponse(resp) - } - - return nil -} diff --git a/pkg/filesystem/driver/remote/client_test.go b/pkg/filesystem/driver/remote/client_test.go deleted file mode 100644 index c195521a..00000000 --- 
a/pkg/filesystem/driver/remote/client_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package remote - -import ( - "context" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "strings" - "testing" -) - -func TestNewClient(t *testing.T) { - a := assert.New(t) - policy := &model.Policy{} - - // 无法解析服务端url - { - policy.Server = string([]byte{0x7f}) - c, err := NewClient(policy) - a.Error(err) - a.Nil(c) - } - - // 成功 - { - policy.Server = "" - c, err := NewClient(policy) - a.NoError(err) - a.NotNil(c) - } -} - -func TestRemoteClient_Upload(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - // 无法创建上传会话 - { - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 分片上传失败,成功删除上传会话 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 分片上传失败,无法删除上传会话 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error2"), - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 成功 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - 
"upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.NoError(err) - clientMock.AssertExpectations(t) - } -} - -func TestRemoteClient_CreateUploadSessionFailed(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} - -func TestRemoteClient_UploadChunkFailed(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)), - }, - }) - err := c.(*remoteClient).uploadChunk(context.Background(), "", 0, strings.NewReader(""), false, 0) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} - -func TestRemoteClient_GetUploadURL(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - // url 解析失败 - { - c.(*remoteClient).policy.Server = string([]byte{0x7f}) - res, sign, err := c.GetUploadURL(0, "") - a.Error(err) - a.Empty(res) - a.Empty(sign) - } - - // 成功 - { - c.(*remoteClient).policy.Server = "" - res, sign, err := c.GetUploadURL(0, "") - a.NoError(err) - a.NotEmpty(res) - a.NotEmpty(sign) - } -} diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go deleted file mode 100644 index 5918f3b4..00000000 --- a/pkg/filesystem/driver/remote/handler.go +++ /dev/null @@ -1,311 +0,0 @@ -package remote - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/url" - "path" - "path/filepath" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// Driver 远程存储策略适配器 -type Driver struct { - Client request.Client - Policy *model.Policy - AuthInstance auth.Auth - - uploadClient Client -} - -// NewDriver initializes a new Driver from policy -// TODO: refactor all method into upload client -func NewDriver(policy *model.Policy) (*Driver, error) { - client, err := NewClient(policy) - if err != nil { - return nil, err - } - - return &Driver{ - Policy: policy, - 
Client: request.NewClient(), - AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)}, - uploadClient: client, - }, nil -} - -// List 列取文件 -func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - var res []response.Object - - reqBody := serializer.ListRequest{ - Path: path, - Recursive: recursive, - } - reqBodyEncoded, err := json.Marshal(reqBody) - if err != nil { - return res, err - } - - // 发送列表请求 - bodyReader := strings.NewReader(string(reqBodyEncoded)) - signTTL := model.GetIntSetting("slave_api_timeout", 60) - resp, err := handler.Client.Request( - "POST", - handler.getAPIUrl("list"), - bodyReader, - request.WithCredential(handler.AuthInstance, int64(signTTL)), - request.WithMasterMeta(), - ).CheckHTTPResponse(200).DecodeResponse() - if err != nil { - return res, err - } - - // 处理列取结果 - if resp.Code != 0 { - return res, errors.New(resp.Error) - } - - if resStr, ok := resp.Data.(string); ok { - err = json.Unmarshal([]byte(resStr), &res) - if err != nil { - return res, err - } - } - - return res, nil -} - -// getAPIUrl 获取接口请求地址 -func (handler *Driver) getAPIUrl(scope string, routes ...string) string { - serverURL, err := url.Parse(handler.Policy.Server) - if err != nil { - return "" - } - var controller *url.URL - - switch scope { - case "delete": - controller, _ = url.Parse("/api/v3/slave/delete") - case "thumb": - controller, _ = url.Parse("/api/v3/slave/thumb") - case "list": - controller, _ = url.Parse("/api/v3/slave/list") - default: - controller = serverURL - } - - for _, r := range routes { - controller.Path = path.Join(controller.Path, r) - } - - return serverURL.ResolveReference(controller).String() -} - -// Get 获取文件内容 -func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 尝试获取速度限制 - speedLimit := 0 - if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { - speedLimit = user.Group.SpeedLimit - } - - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit) - if err != nil { - return nil, err - } - - // 获取文件数据流 - resp, err := handler.Client.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithTimeout(time.Duration(0)), - request.WithMasterMeta(), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - return handler.uploadClient.Upload(ctx, file) -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - // 封装接口请求正文 - reqBody := serializer.RemoteDeleteRequest{ - Files: files, - } - reqBodyEncoded, err := json.Marshal(reqBody) - if err != nil { - return files, err - } - - // 发送删除请求 - bodyReader := strings.NewReader(string(reqBodyEncoded)) - signTTL := model.GetIntSetting("slave_api_timeout", 60) - resp, err := handler.Client.Request( - "POST", - handler.getAPIUrl("delete"), - bodyReader, - request.WithCredential(handler.AuthInstance, int64(signTTL)), - request.WithMasterMeta(), - request.WithSlaveMeta(handler.Policy.AccessKey), - ).CheckHTTPResponse(200).GetResponse() - if err != nil { - return files, err - } - - // 处理删除结果 - var reqResp serializer.Response - err = json.Unmarshal([]byte(resp), &reqResp) - if err != nil { - 
return files, err - } - if reqResp.Code != 0 { - var failedResp serializer.RemoteDeleteRequest - if failed, ok := reqResp.Data.(string); ok { - err = json.Unmarshal([]byte(failed), &failedResp) - if err == nil { - return failedResp.Files, errors.New(reqResp.Error) - } - } - return files, errors.New("unknown format of returned response") - } - - return []string{}, nil -} - -// Thumb 获取文件缩略图 -func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // quick check by extension name - supported := []string{"png", "jpg", "jpeg", "gif"} - if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 { - supported = handler.Policy.OptionsSerialized.ThumbExts - } - - if !util.IsInExtensionList(supported, file.Name) { - return nil, driver.ErrorThumbNotSupported - } - - sourcePath := base64.RawURLEncoding.EncodeToString([]byte(file.SourceName)) - thumbURL := fmt.Sprintf("%s/%s/%s", handler.getAPIUrl("thumb"), sourcePath, filepath.Ext(file.Name)) - ttl := model.GetIntSetting("preview_timeout", 60) - signedThumbURL, err := auth.SignURI(handler.AuthInstance, thumbURL, int64(ttl)) - if err != nil { - return nil, err - } - - return &response.ContentResponse{ - Redirect: true, - URL: signedThumbURL.String(), - }, nil -} - -// Source 获取外链URL -func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - // 尝试从上下文获取文件名 - fileName := "file" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - serverURL, err := url.Parse(handler.Policy.Server) - if err != nil { - return "", errors.New("无法解析远程服务端地址") - } - - // 是否启用了CDN - if handler.Policy.BaseURL != "" { - cdnURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - serverURL = cdnURL - } - - var ( - signedURI *url.URL - controller = "/api/v3/slave/download" - ) - if !isDownload { - controller = "/api/v3/slave/source" - } - - // 签名下载地址 - sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path)) - signedURI, err = auth.SignURI( - handler.AuthInstance, - fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, url.PathEscape(fileName)), - ttl, - ) - - if err != nil { - return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign URL", err) - } - - finalURL := serverURL.ResolveReference(signedURI).String() - return finalURL, nil - -} - -// Token 获取上传策略和认证Token -func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret)) - apiURL := siteURL.ResolveReference(apiBaseURI) - - // 在从机端创建上传会话 - uploadSession.Callback = apiURL.String() - if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl, false); err != nil { - return nil, err - } - - // 获取上传地址 - uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key) - if err != nil { - return nil, fmt.Errorf("failed to sign upload url: %w", err) - } - - return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - UploadURLs: []string{uploadURL}, - Credential: sign, - }, nil -} - -// 取消上传凭证 -func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key) -} diff 
--git a/pkg/filesystem/driver/remote/handler_test.go b/pkg/filesystem/driver/remote/handler_test.go deleted file mode 100644 index 4f6f2392..00000000 --- a/pkg/filesystem/driver/remote/handler_test.go +++ /dev/null @@ -1,460 +0,0 @@ -package remote - -import ( - "context" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/remoteclientmock" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestNewDriver(t *testing.T) { - a := assert.New(t) - - // remoteClient 初始化失败 - { - d, err := NewDriver(&model.Policy{Server: string([]byte{0x7f})}) - a.Error(err) - a.Nil(d) - } - - // 成功 - { - d, err := NewDriver(&model.Policy{}) - a.NoError(err) - a.NotNil(d) - } -} - -func TestHandler_Source(t *testing.T) { - asserts := assert.New(t) - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 无法获取上下文 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - res, err := handler.Source(ctx, "", 0, true, 0) - asserts.NoError(err) - asserts.NotEmpty(res) - } - - // 成功 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/download/0") - } - - // 成功 自定义CDN - { - handler := Driver{ - Policy: &model.Policy{Server: "/", BaseURL: "https://cqu.edu.cn"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/download/0") - asserts.Contains(res, "https://cqu.edu.cn") - } - - // 解析失败 自定义CDN - { - handler := Driver{ - Policy: &model.Policy{Server: "/", BaseURL: string([]byte{0x7f})}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.Error(err) - asserts.Empty(res) - } - - // 成功 预览 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, false, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/source/0") - } -} - -type ClientMock struct { - testMock.Mock -} - -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - args := m.Called(method, target, body, opts) - return args.Get(0).(*request.Response) -} - -func TestHandler_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: "http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() 
- cache.Set("setting_slave_api_timeout", "60", 0) - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Len(failed, 0) - - } - - // 结果解析失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(failed, 2) - } - - // 一个失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203,"data":"{\"files\":[\"1\"]}"}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(failed, 1) - } -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: "http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - cache.Set("setting_slave_api_timeout", "60", 0) - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"[{}]"}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Len(res, 1) - - } - - // 响应解析失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"233"}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 0) - } - - // 从机返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 0) - } -} - -func TestHandler_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: 
"http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - - // 成功 - { - ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - nil, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - resp, err := handler.Get(ctx, "/test.txt") - clientMock.AssertExpectations(t) - asserts.NotNil(resp) - asserts.NoError(err) - } - - // 请求失败 - { - ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - nil, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - resp, err := handler.Get(ctx, "/test.txt") - clientMock.AssertExpectations(t) - asserts.Nil(resp) - asserts.Error(err) - } -} - -func TestHandler_Put(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{ - Type: "remote", - SecretKey: "test", - Server: "http://test.com", - }) - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("Upload", testMock.Anything, testMock.Anything).Return(errors.New("error")) - a.Error(handler.Put(context.Background(), &fsctx.FileStream{})) - clientMock.AssertExpectations(t) -} - -func TestHandler_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - Type: "remote", - SecretKey: "test", - Server: "http://test.com", - OptionsSerialized: model.PolicyOption{ - ThumbExts: []string{"txt"}, - }, - }, - AuthInstance: auth.HMACAuth{}, - } - file := &model.File{ - Name: "1.txt", - SourceName: "1.txt", - } - ctx := context.Background() - asserts.NoError(cache.Set("setting_preview_timeout", "60", 0)) - - // no error - { - resp, err := handler.Thumb(ctx, file) - asserts.NoError(err) - asserts.True(resp.Redirect) - } - - // ext not support - { - file.Name = "1.jpg" - resp, err := handler.Thumb(ctx, file) - asserts.ErrorIs(err, driver.ErrorThumbNotSupported) - asserts.Nil(resp) - } -} - -func TestHandler_Token(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{}) - - // 无法创建上传会话 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(errors.New("error")) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - a.Nil(res) - clientMock.AssertExpectations(t) - } - - // 无法创建上传地址 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(nil) - clientMock.On("GetUploadURL", int64(10), "").Return("", "", errors.New("error")) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - a.Nil(res) - clientMock.AssertExpectations(t) - } - - // 成功 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, 
int64(10), false).Return(nil) - clientMock.On("GetUploadURL", int64(10), "").Return("1", "2", nil) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.NoError(err) - a.NotNil(res) - a.Equal("1", res.UploadURLs[0]) - a.Equal("2", res.Credential) - clientMock.AssertExpectations(t) - } -} - -func TestDriver_CancelToken(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{}) - - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("DeleteUploadSession", testMock.Anything, "key").Return(errors.New("error")) - err := handler.CancelToken(context.Background(), &serializer.UploadSession{Key: "key"}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go deleted file mode 100644 index 56a7aaae..00000000 --- a/pkg/filesystem/driver/s3/handler.go +++ /dev/null @@ -1,440 +0,0 @@ -package s3 - -import ( - "context" - "errors" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "io" - "net/http" - "net/url" - "path" - "path/filepath" - "strings" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Driver 适配器模板 -type Driver struct { - Policy *model.Policy - sess *session.Session - svc *s3.S3 -} - -// UploadPolicy S3上传策略 -type UploadPolicy struct { - Expiration string `json:"expiration"` - Conditions []interface{} `json:"conditions"` -} - -// MetaData 文件信息 -type MetaData struct { - Size uint64 - Etag string -} - -func NewDriver(policy *model.Policy) (*Driver, error) { - if policy.OptionsSerialized.ChunkSize == 0 { - policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB - } - - driver := &Driver{ - Policy: policy, - } - - return driver, driver.InitS3Client() -} - -// InitS3Client 初始化S3会话 -func (handler *Driver) InitS3Client() error { - if handler.Policy == nil { - return errors.New("empty policy") - } - - if handler.svc == nil { - // 初始化会话 - sess, err := session.NewSession(&aws.Config{ - Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""), - Endpoint: &handler.Policy.Server, - Region: &handler.Policy.OptionsSerialized.Region, - S3ForcePathStyle: &handler.Policy.OptionsSerialized.S3ForcePathStyle, - }) - - if err != nil { - return err - } - handler.sess = sess - handler.svc = s3.New(sess) - } - return nil -} - -// List 列出给定路径下的文件 -func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - // 初始化列目录参数 - base = strings.TrimPrefix(base, "/") - if base != "" { - base += "/" - } - - opt := &s3.ListObjectsInput{ - Bucket: &handler.Policy.BucketName, - Prefix: &base, - MaxKeys: aws.Int64(1000), - } - - // 是否为递归列出 - if !recursive { - opt.Delimiter = aws.String("/") - } - - var ( - objects []*s3.Object - commons []*s3.CommonPrefix - ) - 
- for { - res, err := handler.svc.ListObjectsWithContext(ctx, opt) - if err != nil { - return nil, err - } - objects = append(objects, res.Contents...) - commons = append(commons, res.CommonPrefixes...) - - // 如果本次未列取完,则继续使用marker获取结果 - if *res.IsTruncated { - opt.Marker = res.NextMarker - } else { - break - } - } - - // 处理列取结果 - res := make([]response.Object, 0, len(objects)+len(commons)) - - // 处理目录 - for _, object := range commons { - rel, err := filepath.Rel(*opt.Prefix, *object.Prefix) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(*object.Prefix), - RelativePath: filepath.ToSlash(rel), - Size: 0, - IsDir: true, - LastModify: time.Now(), - }) - } - // 处理文件 - for _, object := range objects { - rel, err := filepath.Rel(*opt.Prefix, *object.Key) - if err != nil { - continue - } - res = append(res, response.Object{ - Name: path.Base(*object.Key), - Source: *object.Key, - RelativePath: filepath.ToSlash(rel), - Size: uint64(*object.Size), - IsDir: false, - LastModify: time.Now(), - }) - } - - return res, nil - -} - -// Get 获取文件 -func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - // 获取文件数据流 - client := request.NewClient() - resp, err := client.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithHeader( - http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}}, - ), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil -} - -// Put 将文件流保存到指定目录 -func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { - return err - } - - uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) { - u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize) - }) - - dst := file.Info().SavePath - _, err := uploader.Upload(&s3manager.UploadInput{ - Bucket: &handler.Policy.BucketName, - Key: &dst, - Body: io.LimitReader(file, int64(file.Info().Size)), - }) - - if err != nil { - return err - } - - return nil -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - failed := make([]string, 0, len(files)) - deleted := make([]string, 0, len(files)) - - keys := make([]*s3.ObjectIdentifier, 0, len(files)) - for _, file := range files { - filePath := file - keys = append(keys, &s3.ObjectIdentifier{Key: &filePath}) - } - - // 发送异步删除请求 - res, err := handler.svc.DeleteObjects( - &s3.DeleteObjectsInput{ - Bucket: &handler.Policy.BucketName, - Delete: &s3.Delete{ - Objects: keys, - }, - }) - - if err != nil { - return files, err - } - - // 统计未删除的文件 - for _, deleteRes := range res.Deleted { - deleted = append(deleted, *deleteRes.Key) - } - failed = util.SliceDifference(files, deleted) - - return failed, nil - -} - -// Thumb 获取文件缩略图 -func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - return nil, driver.ErrorThumbNotSupported -} - -// Source 获取外链URL -func (handler *Driver) Source(ctx context.Context, path string, ttl int64, 
isDownload bool, speed int) (string, error) { - - // 尝试从上下文获取文件名 - fileName := "" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - // 初始化客户端 - if err := handler.InitS3Client(); err != nil { - return "", err - } - - contentDescription := aws.String("attachment; filename=\"" + url.PathEscape(fileName) + "\"") - if !isDownload { - contentDescription = nil - } - req, _ := handler.svc.GetObjectRequest( - &s3.GetObjectInput{ - Bucket: &handler.Policy.BucketName, - Key: &path, - ResponseContentDisposition: contentDescription, - }) - - signedURL, err := req.Presign(time.Duration(ttl) * time.Second) - if err != nil { - return "", err - } - - // 将最终生成的签名URL域名换成用户自定义的加速域名(如果有) - finalURL, err := url.Parse(signedURL) - if err != nil { - return "", err - } - - // 公有空间替换掉Key及不支持的头 - if !handler.Policy.IsPrivate { - finalURL.RawQuery = "" - } - - if handler.Policy.BaseURL != "" { - cdnURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - finalURL.Host = cdnURL.Host - finalURL.Scheme = cdnURL.Scheme - } - - return finalURL.String(), nil -} - -// Token 获取上传策略和认证Token -func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 检查文件是否存在 - fileInfo := file.Info() - if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil { - return nil, fmt.Errorf("file already exist") - } - - // 创建分片上传 - expires := time.Now().Add(time.Duration(ttl) * time.Second) - res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{ - Bucket: &handler.Policy.BucketName, - Key: &fileInfo.SavePath, - Expires: &expires, - ContentType: aws.String(fileInfo.DetectMimeType()), - }) - if err != nil { - return nil, fmt.Errorf("failed to create multipart upload: %w", err) - } - - uploadSession.UploadID = *res.UploadId - - // 为每个分片签名上传 URL - chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false) - urls := make([]string, chunks.Num()) - for chunks.Next() { - err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error { - signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{ - Bucket: &handler.Policy.BucketName, - Key: &fileInfo.SavePath, - PartNumber: aws.Int64(int64(c.Index() + 1)), - UploadId: res.UploadId, - }) - - signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second) - if err != nil { - return err - } - - urls[c.Index()] = signedURL - return nil - }) - if err != nil { - return nil, err - } - } - - // 签名完成分片上传的请求URL - signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{ - Bucket: &handler.Policy.BucketName, - Key: &fileInfo.SavePath, - UploadId: res.UploadId, - }) - - signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second) - if err != nil { - return nil, err - } - - // 生成上传凭证 - return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - ChunkSize: handler.Policy.OptionsSerialized.ChunkSize, - UploadID: *res.UploadId, - UploadURLs: urls, - CompleteURL: signedURL, - }, nil -} - -// Meta 获取文件信息 -func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) { - res, err := handler.svc.HeadObject( - &s3.HeadObjectInput{ - Bucket: &handler.Policy.BucketName, - Key: &path, - }) - - if err != nil { - return nil, err - } - - return &MetaData{ - Size: uint64(*res.ContentLength), - Etag: *res.ETag, - }, nil - -} - -// CORS 创建跨域策略 -func (handler *Driver) CORS() 
error { - rule := s3.CORSRule{ - AllowedMethods: aws.StringSlice([]string{ - "GET", - "POST", - "PUT", - "DELETE", - "HEAD", - }), - AllowedOrigins: aws.StringSlice([]string{"*"}), - AllowedHeaders: aws.StringSlice([]string{"*"}), - ExposeHeaders: aws.StringSlice([]string{"ETag"}), - MaxAgeSeconds: aws.Int64(3600), - } - - _, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{ - Bucket: &handler.Policy.BucketName, - CORSConfiguration: &s3.CORSConfiguration{ - CORSRules: []*s3.CORSRule{&rule}, - }, - }) - - return err -} - -// 取消上传凭证 -func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - _, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ - UploadId: &uploadSession.UploadID, - Bucket: &handler.Policy.BucketName, - Key: &uploadSession.SavePath, - }) - return err -} diff --git a/pkg/filesystem/driver/shadow/masterinslave/errors.go b/pkg/filesystem/driver/shadow/masterinslave/errors.go deleted file mode 100644 index 27d04288..00000000 --- a/pkg/filesystem/driver/shadow/masterinslave/errors.go +++ /dev/null @@ -1,7 +0,0 @@ -package masterinslave - -import "errors" - -var ( - ErrNotImplemented = errors.New("this method of shadowed policy is not implemented") -) diff --git a/pkg/filesystem/driver/shadow/masterinslave/handler.go b/pkg/filesystem/driver/shadow/masterinslave/handler.go deleted file mode 100644 index d3f376aa..00000000 --- a/pkg/filesystem/driver/shadow/masterinslave/handler.go +++ /dev/null @@ -1,60 +0,0 @@ -package masterinslave - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Driver 影子存储策略,用于在从机端上传文件 -type Driver struct { - master cluster.Node - handler driver.Handler - policy *model.Policy -} - -// NewDriver 返回新的处理器 -func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler { - return &Driver{ - master: master, - handler: handler, - policy: policy, - } -} - -func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - return d.handler.Put(ctx, file) -} - -func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - return d.handler.Delete(ctx, files) -} - -func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - return "", ErrNotImplemented -} - -func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - return nil, ErrNotImplemented -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil -} diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/errors.go b/pkg/filesystem/driver/shadow/slaveinmaster/errors.go deleted file mode 100644 index 6acadc89..00000000 --- 
a/pkg/filesystem/driver/shadow/slaveinmaster/errors.go +++ /dev/null @@ -1,9 +0,0 @@ -package slaveinmaster - -import "errors" - -var ( - ErrNotImplemented = errors.New("this method of shadowed policy is not implemented") - ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node") - ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result") -) diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go deleted file mode 100644 index bfcac269..00000000 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ /dev/null @@ -1,124 +0,0 @@ -package slaveinmaster - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "net/url" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 -type Driver struct { - node cluster.Node - handler driver.Handler - policy *model.Policy - client request.Client -} - -// NewDriver 返回新的从机指派处理器 -func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler { - var endpoint *url.URL - if serverURL, err := url.Parse(node.DBModel().Server); err == nil { - var controller *url.URL - controller, _ = url.Parse("/api/v3/slave/") - endpoint = serverURL.ResolveReference(controller) - } - - signTTL := model.GetIntSetting("slave_api_timeout", 60) - return &Driver{ - node: node, - handler: handler, - policy: policy, - client: request.NewClient( - request.WithMasterMeta(), - request.WithTimeout(time.Duration(signTTL)*time.Second), - request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)), - request.WithEndpoint(endpoint.String()), - ), - } -} - -// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略 -func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - fileInfo := file.Info() - req := serializer.SlaveTransferReq{ - Src: fileInfo.Src, - Dst: fileInfo.SavePath, - Policy: d.policy, - } - - body, err := json.Marshal(req) - if err != nil { - return err - } - - // 订阅转存结果 - resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0) - defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan) - - res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)). - CheckHTTPResponse(200). 
- DecodeResponse() - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - // 等待转存结果或者超时 - waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800) - select { - case <-time.After(time.Duration(waitTimeout) * time.Second): - return ErrWaitResultTimeout - case msg := <-resChan: - if msg.Event != serializer.SlaveTransferSuccess { - return errors.New(msg.Content.(serializer.SlaveTransferResult).Error) - } - } - - return nil -} - -func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - return d.handler.Delete(ctx, files) -} - -func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - return "", ErrNotImplemented -} - -func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - return nil, ErrNotImplemented -} - -func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - return nil, ErrNotImplemented -} - -// 取消上传凭证 -func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil -} diff --git a/pkg/filesystem/driver/upyun/handler.go b/pkg/filesystem/driver/upyun/handler.go deleted file mode 100644 index a9d18d61..00000000 --- a/pkg/filesystem/driver/upyun/handler.go +++ /dev/null @@ -1,358 +0,0 @@ -package upyun - -import ( - "context" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "path" - "strconv" - "strings" - "sync" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/upyun/go-sdk/upyun" -) - -// UploadPolicy 又拍云上传策略 -type UploadPolicy struct { - Bucket string `json:"bucket"` - SaveKey string `json:"save-key"` - Expiration int64 `json:"expiration"` - CallbackURL string `json:"notify-url"` - ContentLength uint64 `json:"content-length"` - ContentLengthRange string `json:"content-length-range,omitempty"` - AllowFileType string `json:"allow-file-type,omitempty"` -} - -// Driver 又拍云策略适配器 -type Driver struct { - Policy *model.Policy -} - -func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) { - base = strings.TrimPrefix(base, "/") - - // 用于接受SDK返回对象的chan - objChan := make(chan *upyun.FileInfo) - objects := []*upyun.FileInfo{} - - // 列取配置 - listConf := &upyun.GetObjectsConfig{ - Path: "/" + base, - ObjectsChan: objChan, - MaxListTries: 1, - } - // 递归列取时不限制递归次数 - if recursive { - listConf.MaxListLevel = -1 - } - - // 启动一个goroutine收集列取的对象信 - wg := &sync.WaitGroup{} - wg.Add(1) - go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) { - defer wg.Done() - for { - file, ok := <-input - if !ok { - return - } - *output = append(*output, file) - } - }(objChan, &objects, wg) - - up := 
upyun.NewUpYun(&upyun.UpYunConfig{ - Bucket: handler.Policy.BucketName, - Operator: handler.Policy.AccessKey, - Password: handler.Policy.SecretKey, - }) - - err := up.List(listConf) - if err != nil { - return nil, err - } - - wg.Wait() - - // 汇总处理列取结果 - res := make([]response.Object, 0, len(objects)) - for _, object := range objects { - res = append(res, response.Object{ - Name: path.Base(object.Name), - RelativePath: object.Name, - Source: path.Join(base, object.Name), - Size: uint64(object.Size), - IsDir: object.IsDir, - LastModify: object.Time, - }) - } - - return res, nil -} - -// Get 获取文件 -func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 获取文件源地址 - downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - // 获取文件数据流 - client := request.NewClient() - resp, err := client.Request( - "GET", - downloadURL, - nil, - request.WithContext(ctx), - request.WithHeader( - http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}}, - ), - request.WithTimeout(time.Duration(0)), - ).CheckHTTPResponse(200).GetRSCloser() - if err != nil { - return nil, err - } - - resp.SetFirstFakeChunk() - - // 尝试自主获取文件大小 - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - resp.SetContentLength(int64(file.Size)) - } - - return resp, nil - -} - -// Put 将文件流保存到指定目录 -func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error { - defer file.Close() - - up := upyun.NewUpYun(&upyun.UpYunConfig{ - Bucket: handler.Policy.BucketName, - Operator: handler.Policy.AccessKey, - Password: handler.Policy.SecretKey, - }) - err := up.Put(&upyun.PutObjectConfig{ - Path: file.Info().SavePath, - Reader: file, - }) - - return err -} - -// Delete 删除一个或多个文件, -// 返回未删除的文件,及遇到的最后一个错误 -func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { - up := upyun.NewUpYun(&upyun.UpYunConfig{ - Bucket: handler.Policy.BucketName, - Operator: handler.Policy.AccessKey, - Password: handler.Policy.SecretKey, - }) - - var ( - failed = make([]string, 0, len(files)) - lastErr error - currentIndex = 0 - indexLock sync.Mutex - failedLock sync.Mutex - wg sync.WaitGroup - routineNum = 4 - ) - wg.Add(routineNum) - - // upyun不支持批量操作,这里开四个协程并行操作 - for i := 0; i < routineNum; i++ { - go func() { - for { - // 取得待删除文件 - indexLock.Lock() - if currentIndex >= len(files) { - // 所有文件处理完成 - wg.Done() - indexLock.Unlock() - return - } - path := files[currentIndex] - currentIndex++ - indexLock.Unlock() - - // 发送异步删除请求 - err := up.Delete(&upyun.DeleteObjectConfig{ - Path: path, - Async: true, - }) - - // 处理错误 - if err != nil { - failedLock.Lock() - lastErr = err - failed = append(failed, path) - failedLock.Unlock() - } - } - }() - } - - wg.Wait() - - return failed, lastErr -} - -// Thumb 获取文件缩略图 -func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) { - // quick check by extension name - // https://help.upyun.com/knowledge-base/image/ - supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"} - if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 { - supported = handler.Policy.OptionsSerialized.ThumbExts - } - - if !util.IsInExtensionList(supported, file.Name) { - return nil, driver.ErrorThumbNotSupported - } - - var ( - thumbSize = [2]uint{400, 300} - ok = false - ) - if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok { - return nil, errors.New("failed to get thumbnail size") - } - - 
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85) - - thumbParam := fmt.Sprintf("!/fwfh/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality) - thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - if err != nil { - return nil, err - } - - return &response.ContentResponse{ - Redirect: true, - URL: thumbURL, - }, nil -} - -// Source 获取外链URL -func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) { - // 尝试从上下文获取文件名 - fileName := "" - if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - fileName = file.Name - } - - sourceURL, err := url.Parse(handler.Policy.BaseURL) - if err != nil { - return "", err - } - - fileKey, err := url.Parse(url.PathEscape(path)) - if err != nil { - return "", err - } - - sourceURL = sourceURL.ResolveReference(fileKey) - - // 如果是下载文件URL - if isDownload { - query := sourceURL.Query() - query.Add("_upd", fileName) - sourceURL.RawQuery = query.Encode() - } - - return handler.signURL(ctx, sourceURL, ttl) -} - -func (handler Driver) signURL(ctx context.Context, path *url.URL, TTL int64) (string, error) { - if !handler.Policy.IsPrivate { - // 未开启Token防盗链时,直接返回 - return path.String(), nil - } - - etime := time.Now().Add(time.Duration(TTL) * time.Second).Unix() - signStr := fmt.Sprintf( - "%s&%d&%s", - handler.Policy.OptionsSerialized.Token, - etime, - path.Path, - ) - signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr))) - finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10) - - // 将签名添加到URL中 - query := path.Query() - query.Add("_upt", finalSign) - path.RawQuery = query.Encode() - - return path.String(), nil -} - -// Token 获取上传策略和认证Token -func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - // 生成回调地址 - siteURL := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/callback/upyun/" + uploadSession.Key) - apiURL := siteURL.ResolveReference(apiBaseURI) - - // 上传策略 - fileInfo := file.Info() - putPolicy := UploadPolicy{ - Bucket: handler.Policy.BucketName, - // TODO escape - SaveKey: fileInfo.SavePath, - Expiration: time.Now().Add(time.Duration(ttl) * time.Second).Unix(), - CallbackURL: apiURL.String(), - ContentLength: fileInfo.Size, - ContentLengthRange: fmt.Sprintf("0,%d", fileInfo.Size), - AllowFileType: strings.Join(handler.Policy.OptionsSerialized.FileType, ","), - } - - // 生成上传凭证 - policyJSON, err := json.Marshal(putPolicy) - if err != nil { - return nil, err - } - policyEncoded := base64.StdEncoding.EncodeToString(policyJSON) - - // 生成签名 - elements := []string{"POST", "/" + handler.Policy.BucketName, policyEncoded} - signStr := handler.Sign(ctx, elements) - - return &serializer.UploadCredential{ - SessionID: uploadSession.Key, - Policy: policyEncoded, - Credential: signStr, - UploadURLs: []string{"https://v0.api.upyun.com/" + handler.Policy.BucketName}, - }, nil -} - -// 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - return nil -} - -// Sign 计算又拍云的签名头 -func (handler Driver) Sign(ctx context.Context, elements []string) string { - password := fmt.Sprintf("%x", md5.Sum([]byte(handler.Policy.SecretKey))) - mac := hmac.New(sha1.New, []byte(password)) - value := strings.Join(elements, "&") - mac.Write([]byte(value)) - signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil))) - return fmt.Sprintf("UPYUN 
%s:%s", handler.Policy.AccessKey, signStr) -} diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go deleted file mode 100644 index d2670381..00000000 --- a/pkg/filesystem/errors.go +++ /dev/null @@ -1,26 +0,0 @@ -package filesystem - -import ( - "errors" - - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -var ( - ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil) - ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil) - ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil) - ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil) - ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil) - ErrClientCanceled = errors.New("Client canceled operation") - ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil) - ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil) - ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil) - ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil) - ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil) - ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil) - ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil) - ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil) - ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil) - ErrOneObjectOnly = serializer.ParamErr("You can only copy one object at the same time", nil) -) diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go deleted file mode 100644 index a2ddbb1b..00000000 --- a/pkg/filesystem/file.go +++ /dev/null @@ -1,387 +0,0 @@ -package filesystem - -import ( - "context" - "fmt" - "io" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/juju/ratelimit" -) - -/* ============ - 文件相关 - ============ -*/ - -// 限速后的ReaderSeeker -type lrs struct { - response.RSCloser - r io.Reader -} - -func (r lrs) Read(p []byte) (int, error) { - return r.r.Read(p) -} - -// withSpeedLimit 给原有的ReadSeeker加上限速 -func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser { - // 如果用户组有速度限制,就返回限制流速的ReaderSeeker - if fs.User.Group.SpeedLimit != 0 { - speed := fs.User.Group.SpeedLimit - bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed)) - lrs := lrs{rs, ratelimit.Reader(rs, bucket)} - return lrs - } - // 否则返回原始流 - return rs - -} - -// AddFile 新增文件记录 -func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fsctx.FileHeader) (*model.File, error) { - // 添加文件记录前的钩子 - err := fs.Trigger(ctx, "BeforeAddFile", file) - if err != nil { - return nil, err - } - - uploadInfo := file.Info() - newFile := model.File{ - Name: uploadInfo.FileName, - SourceName: uploadInfo.SavePath, - UserID: fs.User.ID, - Size: 
uploadInfo.Size, - FolderID: parent.ID, - PolicyID: fs.Policy.ID, - MetadataSerialized: uploadInfo.Metadata, - UploadSessionID: uploadInfo.UploadSessionID, - } - - err = newFile.Create() - - if err != nil { - if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil { - util.Log().Debug("AfterValidateFailed hook execution failed: %s", err) - } - return nil, ErrFileExisted.WithError(err) - } - - fs.User.Storage += newFile.Size - return &newFile, nil -} - -// GetPhysicalFileContent 根据文件物理路径获取文件流 -func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) { - // 重设上传策略 - fs.Policy = &model.Policy{Type: "local"} - _ = fs.DispatchHandler() - - // 获取文件流 - rs, err := fs.Handler.Get(ctx, path) - if err != nil { - return nil, err - } - - return fs.withSpeedLimit(rs), nil -} - -// Preview 预览文件 -// -// path - 文件虚拟路径 -// isText - 是否为文本文件,文本文件会忽略重定向,直接由 -// 服务端拉取中转给用户,故会对文件大小进行限制 -func (fs *FileSystem) Preview(ctx context.Context, id uint, isText bool) (*response.ContentResponse, error) { - err := fs.resetFileIDIfNotExist(ctx, id) - if err != nil { - return nil, err - } - - // 如果是文本文件预览,需要检查大小限制 - sizeLimit := model.GetIntSetting("maxEditSize", 2<<20) - if isText && fs.FileTarget[0].Size > uint64(sizeLimit) { - return nil, ErrFileSizeTooBig - } - - // 是否直接返回文件内容 - if isText || fs.Policy.IsDirectlyPreview() { - resp, err := fs.GetDownloadContent(ctx, id) - if err != nil { - return nil, err - } - return &response.ContentResponse{ - Redirect: false, - Content: resp, - }, nil - } - - // 否则重定向到签名的预览URL - ttl := model.GetIntSetting("preview_timeout", 60) - previewURL, err := fs.SignURL(ctx, &fs.FileTarget[0], int64(ttl), false) - if err != nil { - return nil, err - } - return &response.ContentResponse{ - Redirect: true, - URL: previewURL, - MaxAge: ttl, - }, nil - -} - -// GetDownloadContent 获取用于下载的文件流 -func (fs *FileSystem) GetDownloadContent(ctx context.Context, id uint) (response.RSCloser, error) { - // 获取原始文件流 - rs, err := fs.GetContent(ctx, id) - if err != nil { - return nil, err - } - - // 返回限速处理后的文件流 - return fs.withSpeedLimit(rs), nil - -} - -// GetContent 获取文件内容,path为虚拟路径 -func (fs *FileSystem) GetContent(ctx context.Context, id uint) (response.RSCloser, error) { - err := fs.resetFileIDIfNotExist(ctx, id) - if err != nil { - return nil, err - } - ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0]) - - // 获取文件流 - rs, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName) - if err != nil { - return nil, ErrIO.WithError(err) - } - - return rs, nil -} - -// deleteGroupedFile 对分组好的文件执行删除操作, -// 返回每个分组失败的文件列表 -func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*model.File) map[uint][]string { - // 失败的文件列表 - // TODO 并行删除 - failed := make(map[uint][]string, len(files)) - thumbs := make([]string, 0) - - for policyID, toBeDeletedFiles := range files { - // 列举出需要物理删除的文件的物理路径 - sourceNamesAll := make([]string, 0, len(toBeDeletedFiles)) - uploadSessions := make([]*serializer.UploadSession, 0, len(toBeDeletedFiles)) - - for i := 0; i < len(toBeDeletedFiles); i++ { - sourceNamesAll = append(sourceNamesAll, toBeDeletedFiles[i].SourceName) - - if toBeDeletedFiles[i].UploadSessionID != nil { - if session, ok := cache.Get(UploadSessionCachePrefix + *toBeDeletedFiles[i].UploadSessionID); ok { - uploadSession := session.(serializer.UploadSession) - uploadSessions = append(uploadSessions, &uploadSession) - } - } - - // Check if sidecar thumb file exist - if 
model.IsTrueVal(toBeDeletedFiles[i].MetadataSerialized[model.ThumbSidecarMetadataKey]) { - thumbs = append(thumbs, toBeDeletedFiles[i].ThumbFile()) - } - } - - // 切换上传策略 - fs.Policy = toBeDeletedFiles[0].GetPolicy() - err := fs.DispatchHandler() - if err != nil { - failed[policyID] = sourceNamesAll - continue - } - - // 取消上传会话 - for _, upSession := range uploadSessions { - if err := fs.Handler.CancelToken(ctx, upSession); err != nil { - util.Log().Warning("Failed to cancel upload session for %q: %s", upSession.Name, err) - } - - cache.Deletes([]string{upSession.Key}, UploadSessionCachePrefix) - } - - // 执行删除 - toBeDeletedSrcs := append(sourceNamesAll, thumbs...) - failedFile, _ := fs.Handler.Delete(ctx, toBeDeletedSrcs) - - // Exclude failed results related to thumb file - failed[policyID] = util.SliceDifference(failedFile, thumbs) - } - - return failed -} - -// GroupFileByPolicy 将目标文件按照存储策略分组 -func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) map[uint][]*model.File { - var policyGroup = make(map[uint][]*model.File) - - for key := range files { - if file, ok := policyGroup[files[key].PolicyID]; ok { - // 如果已存在分组,直接追加 - policyGroup[files[key].PolicyID] = append(file, &files[key]) - } else { - // 分组不存在,创建 - policyGroup[files[key].PolicyID] = make([]*model.File, 0) - policyGroup[files[key].PolicyID] = append(policyGroup[files[key].PolicyID], &files[key]) - } - } - - return policyGroup -} - -// GetDownloadURL 创建文件下载链接, timeout 为数据库中存储过期时间的字段 -func (fs *FileSystem) GetDownloadURL(ctx context.Context, id uint, timeout string) (string, error) { - err := fs.resetFileIDIfNotExist(ctx, id) - if err != nil { - return "", err - } - fileTarget := &fs.FileTarget[0] - - // 生成下載地址 - ttl := model.GetIntSetting(timeout, 60) - source, err := fs.SignURL( - ctx, - fileTarget, - int64(ttl), - true, - ) - if err != nil { - return "", err - } - - return source, nil -} - -// GetSource 获取可直接访问文件的外链地址 -func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) { - // 查找文件记录 - err := fs.resetFileIDIfNotExist(ctx, fileID) - if err != nil { - return "", ErrObjectNotExist.WithError(err) - } - - // 检查存储策略是否可以获得外链 - if !fs.Policy.IsOriginLinkEnable { - return "", serializer.NewError( - serializer.CodePolicyNotAllowed, - "This policy is not enabled for getting source link", - nil, - ) - } - - source, err := fs.SignURL(ctx, &fs.FileTarget[0], 0, false) - if err != nil { - return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err) - } - - return source, nil -} - -// SignURL 签名文件原始 URL -func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) { - fs.FileTarget = []model.File{*file} - ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file) - - err := fs.resetPolicyToFirstFile(ctx) - if err != nil { - return "", err - } - - // 签名最终URL - // 生成外链地址 - source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit) - if err != nil { - return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err) - } - - return source, nil -} - -// ResetFileIfNotExist 重设当前目标文件为 path,如果当前目标为空 -func (fs *FileSystem) ResetFileIfNotExist(ctx context.Context, path string) error { - // 找到文件 - if len(fs.FileTarget) == 0 { - exist, file := fs.IsFileExist(path) - if !exist { - return ErrObjectNotExist - } - fs.FileTarget = []model.File{*file} - } - - // 将当前存储策略重设为文件使用的 - return fs.resetPolicyToFirstFile(ctx) -} - -// ResetFileIfNotExist 重设当前目标文件为 
id,如果当前目标为空 -func (fs *FileSystem) resetFileIDIfNotExist(ctx context.Context, id uint) error { - // 找到文件 - if len(fs.FileTarget) == 0 { - file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID) - if err != nil || len(file) == 0 { - return ErrObjectNotExist - } - fs.FileTarget = []model.File{file[0]} - } - - // 如果上下文限制了父目录,则进行检查 - if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok { - if parent.ID != fs.FileTarget[0].FolderID { - return ErrObjectNotExist - } - } - - // 将当前存储策略重设为文件使用的 - return fs.resetPolicyToFirstFile(ctx) -} - -// resetPolicyToFirstFile 将当前存储策略重设为第一个目标文件文件使用的 -func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error { - if len(fs.FileTarget) == 0 { - return ErrObjectNotExist - } - - // 从机模式不进行操作 - if conf.SystemConfig.Mode == "slave" { - return nil - } - - fs.Policy = fs.FileTarget[0].GetPolicy() - err := fs.DispatchHandler() - if err != nil { - return err - } - return nil -} - -// Search 搜索文件 -func (fs *FileSystem) Search(ctx context.Context, keywords ...interface{}) ([]serializer.Object, error) { - parents := make([]uint, 0) - - // 如果限定了根目录,则只在这个根目录下搜索。 - if fs.Root != nil { - allFolders, err := model.GetRecursiveChildFolder([]uint{fs.Root.ID}, fs.User.ID, true) - if err != nil { - return nil, fmt.Errorf("failed to list all folders: %w", err) - } - - for _, folder := range allFolders { - parents = append(parents, folder.ID) - } - } - - files, _ := model.GetFilesByKeywords(fs.User.ID, parents, keywords...) - fs.SetTargetFile(&files) - - return fs.listObjects(ctx, "/", files, nil, nil), nil -} diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go deleted file mode 100644 index 66f34449..00000000 --- a/pkg/filesystem/file_test.go +++ /dev/null @@ -1,669 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "os" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_AddFile(t *testing.T) { - asserts := assert.New(t) - file := fsctx.FileStream{ - Size: 5, - Name: "1.png", - SavePath: "/Uploads/1_sad.png", - } - folder := model.Folder{ - Model: gorm.Model{ - ID: 1, - }, - } - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Type: "cos", - Model: gorm.Model{ - ID: 1, - }, - }, - }, - Policy: &model.Policy{Type: "cos"}, - } - - _, err := fs.AddFile(context.Background(), &folder, &file) - - asserts.Error(err) - - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - f, err := fs.AddFile(context.Background(), &folder, &file) - - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("/Uploads/1_sad.png", f.SourceName) - - // 前置钩子执行失败 - { - hookExecuted := false - fs.Use("BeforeAddFile", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - hookExecuted = true - return errors.New("error") - }) - f, err := fs.AddFile(context.Background(), &folder, &file) - asserts.Error(err) - asserts.Nil(f) - asserts.True(hookExecuted) - } - - // 后置钩子执行失败 - { - hookExecuted := false - mock.ExpectBegin() - 
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - fs.Hooks = map[string][]Hook{} - fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - hookExecuted = true - return errors.New("error") - }) - f, err := fs.AddFile(context.Background(), &folder, &file) - asserts.Error(err) - asserts.Nil(f) - asserts.True(hookExecuted) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_GetContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Model: gorm.Model{ - ID: 1, - }, - }, - }, - } - - // 文件不存在 - rs, err := fs.GetContent(ctx, 1) - asserts.Equal(ErrObjectNotExist, err) - asserts.Nil(rs) - fs.CleanTargets() - - // 未知存储策略 - file, err := os.Create(util.RelativePath("TestFileSystem_GetContent.txt")) - asserts.NoError(err) - _ = file.Close() - - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent.txt", 1)) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "unknown")) - - rs, err = fs.GetContent(ctx, 1) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 打开文件失败 - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent2.txt", 1)) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "source_name"}).AddRow(1, "local", "not exist")) - - rs, err = fs.GetContent(ctx, 1) - asserts.Equal(serializer.CodeIOFailed, err.(serializer.AppError).Code) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 打开成功 - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetContent.txt", 1, "TestFileSystem_GetContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - rs, err = fs.GetContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_GetDownloadContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Model: gorm.Model{ - ID: 599, - }, - }, - }, - } - file, err := os.Create(util.RelativePath("TestFileSystem_GetDownloadContent.txt")) - asserts.NoError(err) - _ = file.Close() - - cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - // 无限速 - cache.Deletes([]string{"599"}, "policy_") - _, err = fs.GetDownloadContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 有限速 - cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", 
"source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - fs.User.Group.SpeedLimit = 1 - _, err = fs.GetDownloadContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_GroupFileByPolicy(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - files := []model.File{ - model.File{ - PolicyID: 1, - Name: "1_1.txt", - }, - model.File{ - PolicyID: 2, - Name: "2_1.txt", - }, - model.File{ - PolicyID: 3, - Name: "3_1.txt", - }, - model.File{ - PolicyID: 2, - Name: "2_2.txt", - }, - model.File{ - PolicyID: 1, - Name: "1_2.txt", - }, - } - fs := FileSystem{} - policyGroup := fs.GroupFileByPolicy(ctx, files) - asserts.Equal(map[uint][]*model.File{ - 1: {&files[0], &files[4]}, - 2: {&files[1], &files[3]}, - 3: {&files[2]}, - }, policyGroup) -} - -func TestFileSystem_deleteGroupedFile(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{} - files := []model.File{ - { - PolicyID: 1, - Name: "1_1.txt", - SourceName: "1_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 2, - Name: "2_1.txt", - SourceName: "2_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 3, - Name: "3_1.txt", - SourceName: "3_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 2, - Name: "2_2.txt", - SourceName: "2_2.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 1, - Name: "1_2.txt", - SourceName: "1_2.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - } - - // 全部不存在 - { - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } - // 部分不存在 - { - file, err := os.Create(util.RelativePath("1_1.txt")) - asserts.NoError(err) - _ = file.Close() - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } - // 部分失败,包含整组未知存储策略导致的失败 - { - file, err := os.Create(util.RelativePath("1_1.txt")) - asserts.NoError(err) - _ = file.Close() - - files[1].Policy.Type = "unknown" - files[3].Policy.Type = "unknown" - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {"2_1.txt", "2_2.txt"}, - 3: {}, - }, failed) - } - // 包含上传会话文件 - { - sessionID := "session" - cache.Set(UploadSessionCachePrefix+sessionID, serializer.UploadSession{Key: sessionID}, 0) - files[1].Policy.Type = "local" - files[3].Policy.Type = "local" - files[0].UploadSessionID = &sessionID - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - _, ok := cache.Get(UploadSessionCachePrefix + sessionID) - asserts.False(ok) - } - - // 包含缩略图 - { - files[0].MetadataSerialized = map[string]string{ - model.ThumbSidecarMetadataKey: "1", - } - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } -} - -func TestFileSystem_GetSource(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("123")} - - // 正常 - { - fs := FileSystem{ 
- User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 35, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "local", true), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - fs.CleanTargets() - } - - // 文件不存在 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrObjectNotExist.Code, err.(serializer.AppError).Code) - asserts.Empty(sourceURL) - fs.CleanTargets() - } - - // 未知上传策略 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 36, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(36, "?", true), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(sourceURL) - fs.CleanTargets() - } - - // 不允许获取外链 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 37, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(37, "local", false), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(serializer.CodePolicyNotAllowed, err.(serializer.AppError).Code) - asserts.Empty(sourceURL) - fs.CleanTargets() - } -} - -func TestFileSystem_GetDownloadURL(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - auth.General = auth.HMACAuth{SecretKey: []byte("123")} - - // 正常 - { - err := cache.Deletes([]string{"35"}, "policy_") - cache.Set("setting_download_timeout", "20", 0) - cache.Set("setting_siteURL", "https://cloudreve.org", 0) - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). 
- AddRow(35, "local", true), - ) - // 相关设置 - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotEmpty(downloadURL) - fs.CleanTargets() - } - - // 文件不存在 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - err = cache.Deletes([]string{"35"}, "policy_") - err = cache.Deletes([]string{"download_timeout"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"})) - - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(downloadURL) - fs.CleanTargets() - } - - // 未知存储策略 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - err = cache.Deletes([]string{"35"}, "policy_") - err = cache.Deletes([]string{"download_timeout"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "unknown", true), - ) - - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(downloadURL) - fs.CleanTargets() - } -} - -func TestFileSystem_GetPhysicalFileContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - } - - // 文件不存在 - { - rs, err := fs.GetPhysicalFileContent(ctx, "not_exist.txt") - asserts.Error(err) - asserts.Nil(rs) - } - - // 成功 - { - testFile, err := os.Create(util.RelativePath("GetPhysicalFileContent.txt")) - asserts.NoError(err) - asserts.NoError(testFile.Close()) - - rs, err := fs.GetPhysicalFileContent(ctx, "GetPhysicalFileContent.txt") - asserts.NoError(err) - asserts.NoError(rs.Close()) - asserts.NotNil(rs) - } -} - -func TestFileSystem_Preview(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - - // 文件不存在 - { - fs := FileSystem{ - User: &model.User{}, - } - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - resp, err := fs.Preview(ctx, 1, false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(resp) - } - - // 直接返回文件内容,找不到文件 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/no.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, - }, - } - resp, err := fs.Preview(ctx, 1, false) - asserts.Error(err) - asserts.Nil(resp) - } - - // 直接返回文件内容 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, - }, - } - resp, err := fs.Preview(ctx, 1, false) - asserts.Error(err) - asserts.Nil(resp) - } - - // 需要重定向,成功 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "remote", - }, - }, - } - asserts.NoError(cache.Set("setting_preview_timeout", "233", 0)) - resp, err := fs.Preview(ctx, 1, false) - asserts.NoError(err) - asserts.NotNil(resp) - asserts.True(resp.Redirect) - } - - // 文本文件,大小超出限制 - { - fs := FileSystem{ - User: 
&model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "remote", - }, - Size: 11, - }, - } - asserts.NoError(cache.Set("setting_maxEditSize", "10", 0)) - resp, err := fs.Preview(ctx, 1, true) - asserts.Equal(ErrFileSizeTooBig, err) - asserts.Nil(resp) - } -} - -func TestFileSystem_ResetFileIDIfNotExist(t *testing.T) { - asserts := assert.New(t) - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{Model: gorm.Model{ID: 1}}) - fs := FileSystem{ - FileTarget: []model.File{ - { - FolderID: 2, - }, - }, - } - asserts.Equal(ErrObjectNotExist, fs.resetFileIDIfNotExist(ctx, 1)) -} - -func TestFileSystem_Search(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := &FileSystem{ - User: &model.User{}, - } - fs.User.ID = 1 - - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := fs.Search(ctx, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) -} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go deleted file mode 100644 index 1e14fa81..00000000 --- a/pkg/filesystem/filesystem.go +++ /dev/null @@ -1,292 +0,0 @@ -package filesystem - -import ( - "errors" - "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/qiniu" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" - cossdk "github.com/tencentyun/cos-go-sdk-v5" - "net/http" - "net/url" - "sync" -) - -// FSPool 文件系统资源池 -var FSPool = sync.Pool{ - New: func() interface{} { - return &FileSystem{} - }, -} - -// FileSystem 管理文件的文件系统 -type FileSystem struct { - // 文件系统所有者 - User *model.User - // 操作文件使用的存储策略 - Policy *model.Policy - // 当前正在处理的文件对象 - FileTarget []model.File - // 当前正在处理的目录对象 - DirTarget []model.Folder - // 相对根目录 - Root *model.Folder - // 互斥锁 - Lock sync.Mutex - - /* - 钩子函数 - */ - Hooks map[string][]Hook - - /* - 文件系统处理适配器 - */ - Handler driver.Handler - - // 回收锁 - recycleLock sync.Mutex -} - -// getEmptyFS 从pool中获取新的FileSystem -func getEmptyFS() *FileSystem { - fs := FSPool.Get().(*FileSystem) - return fs -} - -// Recycle 回收FileSystem资源 -func (fs *FileSystem) Recycle() { - fs.recycleLock.Lock() - fs.reset() - FSPool.Put(fs) -} - -// reset 重设文件系统,以便回收使用 -func (fs *FileSystem) reset() { - fs.User = nil - fs.CleanTargets() - fs.Policy = nil - fs.Hooks = nil - fs.Handler = nil - fs.Root = nil - fs.Lock = sync.Mutex{} - fs.recycleLock = sync.Mutex{} -} - -// 
NewFileSystem 初始化一个文件系统 -func NewFileSystem(user *model.User) (*FileSystem, error) { - fs := getEmptyFS() - fs.User = user - fs.Policy = &fs.User.Policy - - // 分配存储策略适配器 - err := fs.DispatchHandler() - - return fs, err -} - -// NewAnonymousFileSystem 初始化匿名文件系统 -func NewAnonymousFileSystem() (*FileSystem, error) { - fs := getEmptyFS() - fs.User = &model.User{} - - // 如果是主机模式下,则为匿名文件系统分配游客用户组 - if conf.SystemConfig.Mode == "master" { - anonymousGroup, err := model.GetGroupByID(3) - if err != nil { - return nil, err - } - fs.User.Group = anonymousGroup - } else { - // 从机模式下,分配本地策略处理器 - fs.Handler = local.Driver{} - } - - return fs, nil -} - -// DispatchHandler 根据存储策略分配文件适配器 -func (fs *FileSystem) DispatchHandler() error { - if fs.Policy == nil { - return errors.New("未设置存储策略") - } - policyType := fs.Policy.Type - currentPolicy := fs.Policy - - switch policyType { - case "mock", "anonymous": - return nil - case "local": - fs.Handler = local.Driver{ - Policy: currentPolicy, - } - return nil - case "remote": - handler, err := remote.NewDriver(currentPolicy) - if err != nil { - return err - } - - fs.Handler = handler - case "qiniu": - fs.Handler = qiniu.NewDriver(currentPolicy) - return nil - case "oss": - handler, err := oss.NewDriver(currentPolicy) - fs.Handler = handler - return err - case "upyun": - fs.Handler = upyun.Driver{ - Policy: currentPolicy, - } - return nil - case "onedrive": - var odErr error - fs.Handler, odErr = onedrive.NewDriver(currentPolicy) - return odErr - case "cos": - u, _ := url.Parse(currentPolicy.Server) - b := &cossdk.BaseURL{BucketURL: u} - fs.Handler = cos.Driver{ - Policy: currentPolicy, - Client: cossdk.NewClient(b, &http.Client{ - Transport: &cossdk.AuthorizationTransport{ - SecretID: currentPolicy.AccessKey, - SecretKey: currentPolicy.SecretKey, - }, - }), - HTTPClient: request.NewClient(), - } - return nil - case "s3": - handler, err := s3.NewDriver(currentPolicy) - fs.Handler = handler - return err - case "googledrive": - handler, err := googledrive.NewDriver(currentPolicy) - fs.Handler = handler - return err - default: - return ErrUnknownPolicyType - } - - return nil -} - -// NewFileSystemFromContext 从gin.Context创建文件系统 -func NewFileSystemFromContext(c *gin.Context) (*FileSystem, error) { - user, exist := c.Get("user") - if !exist { - return NewAnonymousFileSystem() - } - fs, err := NewFileSystem(user.(*model.User)) - return fs, err -} - -// NewFileSystemFromCallback 从gin.Context创建回调用文件系统 -func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) { - fs, err := NewFileSystemFromContext(c) - if err != nil { - return nil, err - } - - // 获取回调会话 - callbackSessionRaw, ok := c.Get(UploadSessionCtx) - if !ok { - return nil, errors.New("upload session not exist") - } - callbackSession := callbackSessionRaw.(*serializer.UploadSession) - - // 重新指向上传策略 - fs.Policy = &callbackSession.Policy - err = fs.DispatchHandler() - - return fs, err -} - -// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点 -func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) { - fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, fs.Policy) -} - -// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器 -func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) { - switch fs.Policy.Type { - case "local": - fs.Policy.Type = "remote" - fs.Policy.Server = masterURL - fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID()) - fs.Policy.SecretKey = master.DBModel().MasterKey - fs.DispatchHandler() - case "onedrive": - fs.Policy.MasterID = masterID - } - - 
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy) -} - -// SetTargetFile 设置当前处理的目标文件 -func (fs *FileSystem) SetTargetFile(files *[]model.File) { - if len(fs.FileTarget) == 0 { - fs.FileTarget = *files - } else { - fs.FileTarget = append(fs.FileTarget, *files...) - } - -} - -// SetTargetDir 设置当前处理的目标目录 -func (fs *FileSystem) SetTargetDir(dirs *[]model.Folder) { - if len(fs.DirTarget) == 0 { - fs.DirTarget = *dirs - } else { - fs.DirTarget = append(fs.DirTarget, *dirs...) - } - -} - -// SetTargetFileByIDs 根据文件ID设置目标文件,忽略用户ID -func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error { - files, err := model.GetFilesByIDs(ids, 0) - if err != nil || len(files) == 0 { - return ErrFileExisted.WithError(err) - } - fs.SetTargetFile(&files) - return nil -} - -// SetTargetByInterface 根据 model.File 或者 model.Folder 设置目标对象 -// TODO 测试 -func (fs *FileSystem) SetTargetByInterface(target interface{}) error { - if file, ok := target.(*model.File); ok { - fs.SetTargetFile(&[]model.File{*file}) - return nil - } - if folder, ok := target.(*model.Folder); ok { - fs.SetTargetDir(&[]model.Folder{*folder}) - return nil - } - - return ErrObjectNotExist -} - -// CleanTargets 清空目标 -func (fs *FileSystem) CleanTargets() { - fs.FileTarget = fs.FileTarget[:0] - fs.DirTarget = fs.DirTarget[:0] -} diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go deleted file mode 100644 index 8b7aae37..00000000 --- a/pkg/filesystem/filesystem_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package filesystem - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "net/http/httptest" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - - "testing" -) - -func TestNewFileSystem(t *testing.T) { - asserts := assert.New(t) - user := model.User{ - Policy: model.Policy{ - Type: "local", - }, - } - - // 本地 成功 - fs, err := NewFileSystem(&user) - asserts.NoError(err) - asserts.NotNil(fs.Handler) - asserts.IsType(local.Driver{}, fs.Handler) - // 远程 - user.Policy.Type = "remote" - fs, err = NewFileSystem(&user) - asserts.NoError(err) - asserts.NotNil(fs.Handler) - asserts.IsType(&remote.Driver{}, fs.Handler) - - user.Policy.Type = "unknown" - fs, err = NewFileSystem(&user) - asserts.Error(err) -} - -func TestNewFileSystemFromContext(t *testing.T) { - asserts := assert.New(t) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - fs, err := NewFileSystemFromContext(c) - asserts.NotNil(fs) - asserts.NoError(err) - - c, _ = gin.CreateTestContext(httptest.NewRecorder()) - fs, err = NewFileSystemFromContext(c) - asserts.Nil(fs) - asserts.Error(err) -} - -func TestDispatchHandler(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - Type: "local", - }, - } - - // 未指定,使用用户默认 - err := fs.DispatchHandler() - asserts.NoError(err) - asserts.IsType(local.Driver{}, fs.Handler) - - // 已指定,发生错误 - fs.Policy = &model.Policy{Type: "unknown"} - err = fs.DispatchHandler() - 
asserts.Error(err) - - fs.Policy = &model.Policy{Type: "mock"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "local"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "remote"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "qiniu"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "oss", Server: "https://s.com", BucketName: "1234"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "upyun"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "onedrive"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "cos"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "s3"} - err = fs.DispatchHandler() - asserts.NoError(err) -} - -func TestNewFileSystemFromCallback(t *testing.T) { - asserts := assert.New(t) - - // 用户上下文不存在 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - fs, err := NewFileSystemFromCallback(c) - asserts.Nil(fs) - asserts.Error(err) - } - - // 找不到回调会话 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - fs, err := NewFileSystemFromCallback(c) - asserts.Nil(fs) - asserts.Error(err) - } - - // 成功 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - c.Set(UploadSessionCtx, &serializer.UploadSession{Policy: model.Policy{Type: "local"}}) - fs, err := NewFileSystemFromCallback(c) - asserts.NotNil(fs) - asserts.NoError(err) - } - -} - -func TestFileSystem_SetTargetFileByIDs(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - fs := &FileSystem{} - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - err := fs.SetTargetFileByIDs([]uint{1, 2}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(fs.FileTarget, 1) - asserts.NoError(err) - } - - // 未找到 - { - fs := &FileSystem{} - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.SetTargetFileByIDs([]uint{1, 2}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(fs.FileTarget, 0) - asserts.Error(err) - } -} - -func TestFileSystem_CleanTargets(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - FileTarget: []model.File{{}, {}}, - DirTarget: []model.Folder{{}, {}}, - } - - fs.CleanTargets() - asserts.Len(fs.FileTarget, 0) - asserts.Len(fs.DirTarget, 0) -} - -func TestNewAnonymousFileSystem(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"}).AddRow(3, "游客", "[]")) - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("游客", fs.User.Group.Name) - } - - // 游客用户组不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"})) - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(fs) - } - - // 从机 - { - conf.SystemConfig.Mode = "slave" - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(fs) - asserts.NotNil(fs.Handler) - } -} - -func TestFileSystem_Recycle(t *testing.T) { - fs := &FileSystem{ - User: &model.User{}, - Policy: &model.Policy{}, - FileTarget: []model.File{model.File{}}, - DirTarget: []model.Folder{model.Folder{}}, - Hooks: map[string][]Hook{"AfterUpload": []Hook{GenericAfterUpdate}}, - } - fs.Recycle() - newFS := getEmptyFS() - if fs != newFS { - t.Error("指针不一致") - } -} - -func TestFileSystem_SetTargetByInterface(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{} - - // 目录 - { - asserts.NoError(fs.SetTargetByInterface(&model.Folder{})) - asserts.Len(fs.DirTarget, 1) - asserts.Len(fs.FileTarget, 0) - } - - // 文件 - { - asserts.NoError(fs.SetTargetByInterface(&model.File{})) - asserts.Len(fs.DirTarget, 1) - asserts.Len(fs.FileTarget, 1) - } -} - -func TestFileSystem_SwitchToSlaveHandler(t *testing.T) { - a := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - } - mockNode := &cluster.MasterNode{ - Model: &model.Node{}, - } - fs.SwitchToSlaveHandler(mockNode) - a.IsType(&slaveinmaster.Driver{}, fs.Handler) -} - -func TestFileSystem_SwitchToShadowHandler(t *testing.T) { - a := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{}, - } - mockNode := &cluster.MasterNode{ - Model: &model.Node{}, - } - - // local to remote - { - fs.Policy.Type = "local" - fs.SwitchToShadowHandler(mockNode, "", "") - a.IsType(&masterinslave.Driver{}, fs.Handler) - } - - // onedrive - { - fs.Policy.Type = "onedrive" - fs.SwitchToShadowHandler(mockNode, "", "") - a.IsType(&masterinslave.Driver{}, fs.Handler) - } -} diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go deleted file mode 100644 index 1b7b3be2..00000000 --- a/pkg/filesystem/fsctx/context.go +++ /dev/null @@ -1,44 +0,0 @@ -package fsctx - -type key int - -const ( - // GinCtx Gin的上下文 - GinCtx key = iota - // PathCtx 文件或目录的虚拟路径 - PathCtx - // FileModelCtx 文件数据库模型 - FileModelCtx - // FolderModelCtx 目录数据库模型 - 
FolderModelCtx - // HTTPCtx HTTP请求的上下文 - HTTPCtx - // UploadPolicyCtx 上传策略,一般为slave模式下使用 - UploadPolicyCtx - // UserCtx 用户 - UserCtx - // ThumbSizeCtx 缩略图尺寸 - ThumbSizeCtx - // FileSizeCtx 文件大小 - FileSizeCtx - // ShareKeyCtx 分享文件的 HashID - ShareKeyCtx - // LimitParentCtx 限制父目录 - LimitParentCtx - // IgnoreDirectoryConflictCtx 忽略目录重名冲突 - IgnoreDirectoryConflictCtx - // RetryCtx 失败重试次数 - RetryCtx - // ForceUsePublicEndpointCtx 强制使用公网 Endpoint - ForceUsePublicEndpointCtx - // CancelFuncCtx Context 取消函數 - CancelFuncCtx - // 文件在从机节点中的路径 - SlaveSrcPath - // Webdav目标名称 - WebdavDstName - // WebDAVCtx WebDAV - WebDAVCtx - // WebDAV反代Url - WebDAVProxyUrlCtx -) diff --git a/pkg/filesystem/fsctx/stream.go b/pkg/filesystem/fsctx/stream.go deleted file mode 100644 index 512270b4..00000000 --- a/pkg/filesystem/fsctx/stream.go +++ /dev/null @@ -1,123 +0,0 @@ -package fsctx - -import ( - "errors" - "github.com/HFO4/aliyun-oss-go-sdk/oss" - "io" - "time" -) - -type WriteMode int - -const ( - Overwrite WriteMode = 0x00001 - // Append 只适用于本地策略 - Append WriteMode = 0x00002 - Nop WriteMode = 0x00004 -) - -type UploadTaskInfo struct { - Size uint64 - MimeType string - FileName string - VirtualPath string - Mode WriteMode - Metadata map[string]string - LastModified *time.Time - SavePath string - UploadSessionID *string - AppendStart uint64 - Model interface{} - Src string -} - -// Get mimetype of uploaded file, if it's not defined, detect it from file name -func (u *UploadTaskInfo) DetectMimeType() string { - if u.MimeType != "" { - return u.MimeType - } - - return oss.TypeByExtension(u.FileName) -} - -// FileHeader 上传来的文件数据处理器 -type FileHeader interface { - io.Reader - io.Closer - io.Seeker - Info() *UploadTaskInfo - SetSize(uint64) - SetModel(fileModel interface{}) - Seekable() bool -} - -// FileStream 用户传来的文件 -type FileStream struct { - Mode WriteMode - LastModified *time.Time - Metadata map[string]string - File io.ReadCloser - Seeker io.Seeker - Size uint64 - VirtualPath string - Name string - MimeType string - SavePath string - UploadSessionID *string - AppendStart uint64 - Model interface{} - Src string -} - -func (file *FileStream) Read(p []byte) (n int, err error) { - if file.File != nil { - return file.File.Read(p) - } - - return 0, io.EOF -} - -func (file *FileStream) Close() error { - if file.File != nil { - return file.File.Close() - } - - return nil -} - -func (file *FileStream) Seek(offset int64, whence int) (int64, error) { - if file.Seekable() { - return file.Seeker.Seek(offset, whence) - } - - return 0, errors.New("no seeker") -} - -func (file *FileStream) Seekable() bool { - return file.Seeker != nil -} - -func (file *FileStream) Info() *UploadTaskInfo { - return &UploadTaskInfo{ - Size: file.Size, - MimeType: file.MimeType, - FileName: file.Name, - VirtualPath: file.VirtualPath, - Mode: file.Mode, - Metadata: file.Metadata, - LastModified: file.LastModified, - SavePath: file.SavePath, - UploadSessionID: file.UploadSessionID, - AppendStart: file.AppendStart, - Model: file.Model, - Src: file.Src, - } -} - -func (file *FileStream) SetSize(size uint64) { - file.Size = size -} - -func (file *FileStream) SetModel(fileModel interface{}) { - file.Model = fileModel -} diff --git a/pkg/filesystem/fsctx/stream_test.go b/pkg/filesystem/fsctx/stream_test.go deleted file mode 100644 index 1ef6e1fa..00000000 --- a/pkg/filesystem/fsctx/stream_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package fsctx - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "io" - 
"io/ioutil" - "os" - "strings" - "testing" -) - -func TestFileStream_Read(t *testing.T) { - asserts := assert.New(t) - file := FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - } - var p = make([]byte, 3) - { - n, err := file.Read(p) - asserts.Equal(3, n) - asserts.NoError(err) - } -} - -func TestFileStream_Close(t *testing.T) { - asserts := assert.New(t) - { - file := FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - } - err := file.Close() - asserts.NoError(err) - } - - { - file := FileStream{} - err := file.Close() - asserts.NoError(err) - } -} - -func TestFileStream_Seek(t *testing.T) { - asserts := assert.New(t) - f, _ := os.CreateTemp("", "*") - defer func() { - f.Close() - os.Remove(f.Name()) - }() - { - file := FileStream{ - File: f, - Seeker: f, - } - res, err := file.Seek(0, io.SeekStart) - asserts.NoError(err) - asserts.EqualValues(0, res) - } - - { - file := FileStream{} - res, err := file.Seek(0, io.SeekStart) - asserts.Error(err) - asserts.EqualValues(0, res) - } -} - -func TestFileStream_Info(t *testing.T) { - a := assert.New(t) - file := FileStream{} - a.NotNil(file.Info()) - - file.SetSize(10) - a.EqualValues(10, file.Info().Size) - - file.SetModel(&model.File{}) - a.NotNil(file.Info().Model) -} diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go deleted file mode 100644 index a2f9ed5f..00000000 --- a/pkg/filesystem/hooks.go +++ /dev/null @@ -1,304 +0,0 @@ -package filesystem - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io/ioutil" - "net/http" - "strconv" - "strings" - "time" -) - -// Hook 钩子函数 -type Hook func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error - -// Use 注入钩子 -func (fs *FileSystem) Use(name string, hook Hook) { - if fs.Hooks == nil { - fs.Hooks = make(map[string][]Hook) - } - if _, ok := fs.Hooks[name]; ok { - fs.Hooks[name] = append(fs.Hooks[name], hook) - return - } - fs.Hooks[name] = []Hook{hook} -} - -// CleanHooks 清空钩子,name为空表示全部清空 -func (fs *FileSystem) CleanHooks(name string) { - if name == "" { - fs.Hooks = nil - } else { - delete(fs.Hooks, name) - } -} - -// Trigger 触发钩子,遇到第一个错误时 -// 返回错误,后续钩子不会继续执行 -func (fs *FileSystem) Trigger(ctx context.Context, name string, file fsctx.FileHeader) error { - if hooks, ok := fs.Hooks[name]; ok { - for _, hook := range hooks { - err := hook(ctx, fs, file) - if err != nil { - util.Log().Warning("Failed to execute hook:%s", err) - return err - } - } - } - return nil -} - -// HookValidateFile 一系列对文件检验的集合 -func HookValidateFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - fileInfo := file.Info() - - // 验证单文件尺寸 - if !fs.ValidateFileSize(ctx, fileInfo.Size) { - return ErrFileSizeTooBig - } - - // 验证文件名 - if !fs.ValidateLegalName(ctx, fileInfo.FileName) { - return ErrIllegalObjectName - } - - // 验证扩展名 - if !fs.ValidateExtension(ctx, fileInfo.FileName) { - return ErrFileExtensionNotAllowed - } - - return nil - -} - -// HookResetPolicy 重设存储策略为上下文已有文件 -func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return ErrObjectNotExist - } - - fs.Policy = originFile.GetPolicy() - return 
fs.DispatchHandler() -} - -// HookValidateCapacity 验证用户容量 -func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - // 验证并扣除容量 - if fs.User.GetRemainingCapacity() < file.Info().Size { - return ErrInsufficientCapacity - } - return nil -} - -// HookValidateCapacityDiff 根据原有文件和新文件的大小验证用户容量 -func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { - originFile := ctx.Value(fsctx.FileModelCtx).(model.File) - newFileSize := newFile.Info().Size - - if newFileSize > originFile.Size { - return HookValidateCapacity(ctx, fs, newFile) - } - - return nil -} - -// HookDeleteTempFile 删除已保存的临时文件 -func HookDeleteTempFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - // 删除临时文件 - _, err := fs.Handler.Delete(ctx, []string{file.Info().SavePath}) - if err != nil { - util.Log().Warning("Failed to clean-up temp files: %s", err) - } - - return nil -} - -// HookCleanFileContent 清空文件内容 -func HookCleanFileContent(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - // 清空内容 - return fs.Handler.Put(ctx, &fsctx.FileStream{ - File: ioutil.NopCloser(strings.NewReader("")), - SavePath: file.Info().SavePath, - Size: 0, - Mode: fsctx.Overwrite, - }) -} - -// HookClearFileSize 将原始文件的尺寸设为0 -func HookClearFileSize(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return ErrObjectNotExist - } - return originFile.UpdateSize(0) -} - -// HookCancelContext 取消上下文 -func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc) - if ok { - cancelFunc() - } - return nil -} - -// HookUpdateSourceName 更新文件SourceName -func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return ErrObjectNotExist - } - return originFile.UpdateSourceName(originFile.SourceName) -} - -// GenericAfterUpdate 文件内容更新后 -func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { - // 更新文件尺寸 - originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return ErrObjectNotExist - } - - newFile.SetModel(&originFile) - - err := originFile.UpdateSize(newFile.Info().Size) - if err != nil { - return err - } - - return nil -} - -// SlaveAfterUpload Slave模式下上传完成钩子 -func SlaveAfterUpload(session *serializer.UploadSession) Hook { - return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - if session.Callback == "" { - return nil - } - - // 发送回调请求 - callbackBody := serializer.UploadCallback{} - return cluster.RemoteCallback(session.Callback, callbackBody) - } -} - -// GenericAfterUpload 文件上传完成后,包含数据库操作 -func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileInfo := fileHeader.Info() - - // 创建或查找根目录 - folder, err := fs.CreateDirectory(ctx, fileInfo.VirtualPath) - if err != nil { - return err - } - - // 检查文件是否存在 - if ok, file := fs.IsChildFileExist( - folder, - fileInfo.FileName, - ); ok { - if file.UploadSessionID != nil { - return ErrFileUploadSessionExisted - } - - return ErrFileExisted - } - - // 向数据库中插入记录 - file, err := fs.AddFile(ctx, folder, fileHeader) - if err != nil { - return ErrInsertFileRecord - } - fileHeader.SetModel(file) - - return nil -} - -// HookClearFileHeaderSize 将FileHeader大小设定为0 -func 
HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileHeader.SetSize(0) - return nil -} - -// HookTruncateFileTo 将物理文件截断至 size -func HookTruncateFileTo(size uint64) Hook { - return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - if handler, ok := fs.Handler.(local.Driver); ok { - return handler.Truncate(ctx, fileHeader.Info().SavePath, size) - } - - return nil - } -} - -// HookChunkUploadFinished 单个分片上传结束后 -func HookChunkUploaded(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileInfo := fileHeader.Info() - - // 更新文件大小 - return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart + fileInfo.Size) -} - -// HookChunkUploadFailed 单个分片上传失败后 -func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileInfo := fileHeader.Info() - - // 更新文件大小 - return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart) -} - -// HookPopPlaceholderToFile 将占位文件提升为正式文件 -func HookPopPlaceholderToFile(picInfo string) Hook { - return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fileInfo := fileHeader.Info() - fileModel := fileInfo.Model.(*model.File) - return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo) - } -} - -// HookChunkUploadFinished 分片上传结束后处理文件 -func HookDeleteUploadSession(id string) Hook { - return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - cache.Deletes([]string{id}, UploadSessionCachePrefix) - return nil - } -} - -// NewWebdavAfterUploadHook 每次创建一个新的钩子函数 rclone 在 PUT 请求里有 OC-Checksum 字符串 -// 和 X-OC-Mtime -func NewWebdavAfterUploadHook(request *http.Request) func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { - var modtime time.Time - if timeVal := request.Header.Get("X-OC-Mtime"); timeVal != "" { - timeUnix, err := strconv.ParseInt(timeVal, 10, 64) - if err == nil { - modtime = time.Unix(timeUnix, 0) - } - } - checksum := request.Header.Get("OC-Checksum") - - return func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error { - file := newFile.Info().Model.(*model.File) - if !modtime.IsZero() { - err := model.DB.Model(file).UpdateColumn("updated_at", modtime).Error - if err != nil { - return err - } - } - - if checksum != "" { - return file.UpdateMetadata(map[string]string{ - model.ChecksumMetadataKey: checksum, - }) - } - - return nil - } -} diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go deleted file mode 100644 index cc660ce9..00000000 --- a/pkg/filesystem/hooks_test.go +++ /dev/null @@ -1,708 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io/ioutil" - "net/http" - "strings" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestGenericBeforeUpload(t *testing.T) { - asserts := assert.New(t) - file := &fsctx.FileStream{ - Size: 5, - Name: "1.txt", - } - ctx := context.Background() - cache.Set("pack_size_0", uint64(0), 0) - 
fs := FileSystem{ - User: &model.User{ - Storage: 0, - Group: model.Group{ - MaxStorage: 11, - }, - }, - Policy: &model.Policy{ - MaxSize: 4, - OptionsSerialized: model.PolicyOption{ - FileType: []string{"txt"}, - }, - }, - } - - asserts.Error(HookValidateFile(ctx, &fs, file)) - - file.Size = 1 - file.Name = "1" - asserts.Error(HookValidateFile(ctx, &fs, file)) - - file.Name = "1.txt" - asserts.NoError(HookValidateFile(ctx, &fs, file)) - - file.Name = "1.t/xt" - asserts.Error(HookValidateFile(ctx, &fs, file)) -} - -func TestGenericAfterUploadCanceled(t *testing.T) { - asserts := assert.New(t) - file := &fsctx.FileStream{ - Size: 5, - Name: "TestGenericAfterUploadCanceled", - SavePath: "TestGenericAfterUploadCanceled", - } - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - } - - // 成功 - { - mockHandler := &FileHeaderMock{} - fs.Handler = mockHandler - mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, nil) - err := HookDeleteTempFile(ctx, &fs, file) - asserts.NoError(err) - mockHandler.AssertExpectations(t) - } - - // 失败 - { - mockHandler := &FileHeaderMock{} - fs.Handler = mockHandler - mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, errors.New("")) - err := HookDeleteTempFile(ctx, &fs, file) - asserts.NoError(err) - mockHandler.AssertExpectations(t) - } - -} - -func TestGenericAfterUpload(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{}, - } - - ctx := context.Background() - file := &fsctx.FileStream{ - VirtualPath: "/我的文件", - Name: "test.txt", - } - - // 正常 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := GenericAfterUpload(ctx, &fs, file) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 文件已存在 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows( - mock.NewRows([]string{"name"}).AddRow("test.txt"), - ) - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrFileExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 文件已存在, 且为上传占位符 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows( - mock.NewRows([]string{"name", "upload_session_id"}).AddRow("test.txt", "1"), - ) - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrFileUploadSessionExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 插入失败 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrInsertFileRecord, err) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestFileSystem_Use(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{} - - hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - return nil - } - - // 添加一个 - fs.Use("BeforeUpload", hook) - asserts.Len(fs.Hooks["BeforeUpload"], 1) - - // 添加一个 - fs.Use("BeforeUpload", hook) - asserts.Len(fs.Hooks["BeforeUpload"], 2) - - // 不存在 - fs.Use("BeforeUpload2333", hook) - - asserts.NotPanics(func() { - for _, hookName := range []string{ - "AfterUpload", - "AfterValidateFailed", - "AfterUploadCanceled", - "BeforeFileDownload", - } { - fs.Use(hookName, hook) - } - }) - -} - -func TestFileSystem_Trigger(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - } - ctx := context.Background() - - hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fs.User.Storage++ - return nil - } - - // 一个 - fs.Use("BeforeUpload", hook) - err := fs.Trigger(ctx, "BeforeUpload", nil) - asserts.NoError(err) - asserts.Equal(uint64(1), fs.User.Storage) - - // 多个 - fs.Use("BeforeUpload", hook) - fs.Use("BeforeUpload", hook) - err = fs.Trigger(ctx, "BeforeUpload", nil) - asserts.NoError(err) - asserts.Equal(uint64(4), fs.User.Storage) - - // 多个,有失败 - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - asserts.Fail("following hooks executed") - return nil - }) - err = fs.Trigger(ctx, "BeforeUpload", nil) - asserts.Error(err) -} - -func TestHookValidateCapacity(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - Storage: 0, - Group: model.Group{ - MaxStorage: 11, - }, - }} - ctx := context.Background() - file := &fsctx.FileStream{Size: 11} - { - err := HookValidateCapacity(ctx, fs, file) - asserts.NoError(err) - } - { - file.Size = 12 - err := HookValidateCapacity(ctx, fs, file) - asserts.Error(err) - } -} - -func TestHookValidateCapacityDiff(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{User: &model.User{ - Group: model.Group{ - MaxStorage: 11, - }, - }} - file := model.File{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - - // 无需操作 - { - a.NoError(HookValidateCapacityDiff(ctx, fs, 
&fsctx.FileStream{Size: 10})) - } - - // 需要验证 - { - a.Error(HookValidateCapacityDiff(ctx, fs, &fsctx.FileStream{Size: 12})) - } - -} - -func TestHookResetPolicy(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - file := model.File{PolicyID: 2} - cache.Deletes([]string{"2"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(2, "local")) - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - err := HookResetPolicy(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文文件不存在 - { - cache.Deletes([]string{"2"}, "policy_") - ctx := context.Background() - err := HookResetPolicy(ctx, fs, nil) - asserts.Error(err) - } -} - -func TestHookCleanFileContent(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - file := &fsctx.FileStream{SavePath: "123/123"} - handlerMock := FileHeaderMock{} - handlerMock.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error")) - fs.Handler = handlerMock - err := HookCleanFileContent(context.Background(), fs, file) - asserts.Error(err) - handlerMock.AssertExpectations(t) -} - -func TestHookClearFileSize(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - ctx := context.WithValue( - context.Background(), - fsctx.FileModelCtx, - model.File{Model: gorm.Model{ID: 1}, Size: 10}, - ) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)"). - WithArgs("", 0, sqlmock.AnyArg(), 1, 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := HookClearFileSize(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文对象不存在 - { - ctx := context.Background() - err := HookClearFileSize(ctx, fs, nil) - asserts.Error(err) - } - -} - -func TestHookUpdateSourceName(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - SourceName: "new.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("", "new.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := HookUpdateSourceName(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文错误 - { - ctx := context.Background() - err := HookUpdateSourceName(ctx, fs, nil) - asserts.Error(err) - } -} - -func TestGenericAfterUpdate(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 是图像文件 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - PicInfo: "1,1", - } - newFile := &fsctx.FileStream{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - - handlerMock := FileHeaderMock{} - handlerMock.On("Delete", testMock.Anything, []string{"._thumb"}).Return([]string{}, nil) - fs.Handler = handlerMock - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)"). - WithArgs("", 10, sqlmock.AnyArg(), 1, 0). 
- WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := GenericAfterUpdate(ctx, fs, newFile) - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 原始文件上下文不存在 - { - newFile := &fsctx.FileStream{Size: 10} - ctx := context.Background() - err := GenericAfterUpdate(ctx, fs, newFile) - asserts.Error(err) - } - - // 无法更新数据库容量 - // 成功 是图像文件 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - PicInfo: "1,1", - } - newFile := &fsctx.FileStream{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs("", 10, sqlmock.AnyArg(), 1, 0). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - - err := GenericAfterUpdate(ctx, fs, newFile) - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestSlaveAfterUpload(t *testing.T) { - asserts := assert.New(t) - conf.SystemConfig.Mode = "slave" - fs, err := NewAnonymousFileSystem() - conf.SystemConfig.Mode = "master" - asserts.NoError(err) - - // 成功 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/callbakc", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - request.GeneralClient = clientMock - file := &fsctx.FileStream{ - Size: 10, - VirtualPath: "/my", - Name: "test.txt", - SavePath: "/not_exist", - } - err := SlaveAfterUpload(&serializer.UploadSession{Callback: "http://test/callbakc"})(context.Background(), fs, file) - clientMock.AssertExpectations(t) - asserts.NoError(err) - } - - // 跳过回调 - { - file := &fsctx.FileStream{ - Size: 10, - VirtualPath: "/my", - Name: "test.txt", - SavePath: "/not_exist", - } - err := SlaveAfterUpload(&serializer.UploadSession{})(context.Background(), fs, file) - asserts.NoError(err) - } -} - -func TestFileSystem_CleanHooks(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - Hooks: map[string][]Hook{ - "hook1": []Hook{}, - "hook2": []Hook{}, - "hook3": []Hook{}, - }, - } - - // 清理一个 - { - fs.CleanHooks("hook2") - asserts.Len(fs.Hooks, 2) - asserts.Contains(fs.Hooks, "hook1") - asserts.Contains(fs.Hooks, "hook3") - } - - // 清理全部 - { - fs.CleanHooks("") - asserts.Len(fs.Hooks, 0) - } -} - -func TestHookCancelContext(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{} - ctx, cancel := context.WithCancel(context.Background()) - - // empty ctx - { - asserts.NoError(HookCancelContext(ctx, fs, nil)) - select { - case <-ctx.Done(): - t.Errorf("Channel should not be closed") - default: - - } - } - - // with cancel ctx - { - ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) - asserts.NoError(HookCancelContext(ctx, fs, nil)) - _, ok := <-ctx.Done() - asserts.False(ok) - } -} - -func TestHookClearFileHeaderSize(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{Size: 10} - a.NoError(HookClearFileHeaderSize(context.Background(), fs, file)) - a.EqualValues(0, file.Size) -} - -func TestHookTruncateFileTo(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{} - a.NoError(HookTruncateFileTo(0)(context.Background(), fs, file)) - - fs.Handler = local.Driver{} - 
a.Error(HookTruncateFileTo(0)(context.Background(), fs, file)) -} - -func TestHookChunkUploaded(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - AppendStart: 10, - Size: 10, - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 20, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(20, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookChunkUploaded(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookChunkUploadFailed(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - AppendStart: 10, - Size: 10, - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 10, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookChunkUploadFailed(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookPopPlaceholderToFile(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookPopPlaceholderToFile("1,1")(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookPopPlaceholderToFileBySuffix(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{ - Policy: &model.Policy{Type: "cos"}, - } - file := &fsctx.FileStream{ - Name: "1.png", - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookPopPlaceholderToFile("")(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookDeleteUploadSession(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - cache.Set(UploadSessionCachePrefix+"TestHookDeleteUploadSession", "", 0) - a.NoError(HookDeleteUploadSession("TestHookDeleteUploadSession")(context.Background(), fs, file)) - _, ok := cache.Get(UploadSessionCachePrefix + "TestHookDeleteUploadSession") - a.False(ok) -} -func TestNewWebdavAfterUploadHook(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - req, _ := http.NewRequest("get", "http://localhost", nil) - req.Header.Add("X-Oc-Mtime", "1681521402") - req.Header.Add("OC-Checksum", "checksum") - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := NewWebdavAfterUploadHook(req)(context.Background(), fs, file) - a.NoError(err) - a.NoError(mock.ExpectationsWereMet()) - -} diff --git a/pkg/filesystem/image.go b/pkg/filesystem/image.go deleted file mode 100644 index 2563cc9d..00000000 --- a/pkg/filesystem/image.go +++ /dev/null @@ -1,219 +0,0 
@@ -package filesystem - -import ( - "context" - "errors" - "fmt" - "os" - "sync" - - "runtime" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -/* ================ - 图像处理相关 - ================ -*/ - -// GetThumb 获取文件的缩略图 -func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentResponse, error) { - // 根据 ID 查找文件 - err := fs.resetFileIDIfNotExist(ctx, id) - if err != nil { - return nil, ErrObjectNotExist - } - - file := fs.FileTarget[0] - if !file.ShouldLoadThumb() { - return nil, ErrObjectNotExist - } - - w, h := fs.GenerateThumbnailSize(0, 0) - ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h}) - ctx = context.WithValue(ctx, fsctx.FileModelCtx, file) - res, err := fs.Handler.Thumb(ctx, &file) - if errors.Is(err, driver.ErrorThumbNotExist) { - // Regenerate thumb if the thumb is not initialized yet - if generateErr := fs.generateThumbnail(ctx, &file); generateErr == nil { - res, err = fs.Handler.Thumb(ctx, &file) - } else { - err = generateErr - } - } else if errors.Is(err, driver.ErrorThumbNotSupported) { - // Policy handler explicitly indicates thumb not available, check if proxy is enabled - if fs.Policy.CouldProxyThumb() { - // if thumb id marked as existed, redirect to "sidecar" thumb file. - if file.MetadataSerialized != nil && - file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusExist { - // redirect to sidecar file - res = &response.ContentResponse{ - Redirect: true, - } - res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0) - } else { - // if not exist, generate and upload the sidecar thumb. 
- if err = fs.generateThumbnail(ctx, &file); err == nil { - return fs.GetThumb(ctx, id) - } - } - } else { - // thumb not supported and proxy is disabled, mark as not available - _ = updateThumbStatus(&file, model.ThumbStatusNotAvailable) - } - } - - if err == nil && conf.SystemConfig.Mode == "master" { - res.MaxAge = model.GetIntSetting("preview_timeout", 60) - } - - return res, err -} - -// thumbPool 要使用的任务池 -var thumbPool *Pool -var once sync.Once - -// Pool 带有最大配额的任务池 -type Pool struct { - // 容量 - worker chan int -} - -// Init 初始化任务池 -func getThumbWorker() *Pool { - once.Do(func() { - maxWorker := model.GetIntSetting("thumb_max_task_count", -1) - if maxWorker <= 0 { - maxWorker = runtime.GOMAXPROCS(0) - } - thumbPool = &Pool{ - worker: make(chan int, maxWorker), - } - util.Log().Debug("Initialize thumbnails task queue with: WorkerNum = %d", maxWorker) - }) - return thumbPool -} -func (pool *Pool) addWorker() { - pool.worker <- 1 - util.Log().Debug("Worker added to thumbnails task queue.") -} -func (pool *Pool) releaseWorker() { - util.Log().Debug("Worker released from thumbnails task queue.") - <-pool.worker -} - -// generateThumbnail generates thumb for given file, upload the thumb file back with given suffix -func (fs *FileSystem) generateThumbnail(ctx context.Context, file *model.File) error { - // 新建上下文 - newCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - // TODO: check file size - - if file.Size > uint64(model.GetIntSetting("thumb_max_src_size", 31457280)) { - _ = updateThumbStatus(file, model.ThumbStatusNotAvailable) - return errors.New("file too large") - } - - getThumbWorker().addWorker() - defer getThumbWorker().releaseWorker() - - // 获取文件数据 - source, err := fs.Handler.Get(newCtx, file.SourceName) - if err != nil { - return fmt.Errorf("faield to fetch original file %q: %w", file.SourceName, err) - } - defer source.Close() - - // Provide file source path for local policy files - src := "" - if conf.SystemConfig.Mode == "slave" || file.GetPolicy().Type == "local" { - src = file.SourceName - } - - thumbRes, err := thumb.Generators.Generate(ctx, source, src, file.Name, model.GetSettingByNames( - "thumb_width", - "thumb_height", - "thumb_builtin_enabled", - "thumb_vips_enabled", - "thumb_ffmpeg_enabled", - "thumb_libreoffice_enabled", - "thumb_libraw_enabled", - )) - if err != nil { - _ = updateThumbStatus(file, model.ThumbStatusNotAvailable) - return fmt.Errorf("failed to generate thumb for %q: %w", file.Name, err) - } - - defer os.Remove(thumbRes.Path) - - thumbFile, err := os.Open(thumbRes.Path) - if err != nil { - return fmt.Errorf("failed to open temp thumb %q: %w", thumbRes.Path, err) - } - - defer thumbFile.Close() - fileInfo, err := thumbFile.Stat() - if err != nil { - return fmt.Errorf("failed to stat temp thumb %q: %w", thumbRes.Path, err) - } - - if err = fs.Handler.Put(newCtx, &fsctx.FileStream{ - Mode: fsctx.Overwrite, - File: thumbFile, - Seeker: thumbFile, - Size: uint64(fileInfo.Size()), - SavePath: file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"), - }); err != nil { - return fmt.Errorf("failed to save thumb for %q: %w", file.Name, err) - } - - if model.IsTrueVal(model.GetSettingByName("thumb_gc_after_gen")) { - util.Log().Debug("generateThumbnail runtime.GC") - runtime.GC() - } - - // Mark this file as thumb available - err = updateThumbStatus(file, model.ThumbStatusExist) - - // 失败时删除缩略图文件 - if err != nil { - _, _ = fs.Handler.Delete(newCtx, []string{file.SourceName + 
model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")}) - } - - return nil -} - -// GenerateThumbnailSize 获取要生成的缩略图的尺寸 -func (fs *FileSystem) GenerateThumbnailSize(w, h int) (uint, uint) { - return uint(model.GetIntSetting("thumb_width", 400)), uint(model.GetIntSetting("thumb_height", 300)) -} - -func updateThumbStatus(file *model.File, status string) error { - if file.Model.ID > 0 { - meta := map[string]string{ - model.ThumbStatusMetadataKey: status, - } - - if status == model.ThumbStatusExist { - meta[model.ThumbSidecarMetadataKey] = "true" - } - - return file.UpdateMetadata(meta) - } else { - if file.MetadataSerialized == nil { - file.MetadataSerialized = map[string]string{} - } - - file.MetadataSerialized[model.ThumbStatusMetadataKey] = status - } - - return nil -} diff --git a/pkg/filesystem/image_test.go b/pkg/filesystem/image_test.go deleted file mode 100644 index 41808583..00000000 --- a/pkg/filesystem/image_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/thumbmock" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" - testMock "github.com/stretchr/testify/mock" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_GetThumb(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{User: &model.User{}} - - // file not found - { - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, ErrObjectNotExist) - a.Nil(res) - a.NoError(mock.ExpectationsWereMet()) - } - - // thumb not exist - { - fs.SetTargetFile(&[]model.File{{ - MetadataSerialized: map[string]string{ - model.ThumbStatusMetadataKey: model.ThumbStatusNotAvailable, - }, - Policy: model.Policy{Type: "mock"}, - }}) - fs.FileTarget[0].Policy.ID = 1 - - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, ErrObjectNotExist) - a.Nil(res) - } - - // thumb not initialized, also failed to generate - { - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - Size: 31457281, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "file too large") - a.Nil(res.Content) - } - - // thumb not initialized, failed to get source - { - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, errors.New("error")) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "error") - a.Nil(res.Content) - } - - // thumb not initialized, no available generators - { - thumb.Generators = []thumb.Generator{} - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "local"}, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", 
testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, thumb.ErrNotAvailable) - a.Nil(res) - } - - // thumb not initialized, thumb generated but cannot be open - { - mockGenerator := &thumbmock.GeneratorMock{} - thumb.Generators = []thumb.Generator{mockGenerator} - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - }}) - cache.Set("setting_thumb_vips_enabled", "1", 0) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil) - mockGenerator.On("Generate", testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything). - Return(&thumb.Result{Path: "not_exit_thumb"}, nil) - - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "failed to open temp thumb") - a.Nil(res.Content) - testHandller2.AssertExpectations(t) - mockGenerator.AssertExpectations(t) - } -} - -func TestFileSystem_ThumbWorker(t *testing.T) { - asserts := assert.New(t) - - asserts.NotPanics(func() { - getThumbWorker().addWorker() - getThumbWorker().releaseWorker() - }) -} diff --git a/pkg/filesystem/manage.go b/pkg/filesystem/manage.go deleted file mode 100644 index c77c9d92..00000000 --- a/pkg/filesystem/manage.go +++ /dev/null @@ -1,479 +0,0 @@ -package filesystem - -import ( - "context" - "fmt" - "path" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -/* ================= - 文件/目录管理 - ================= -*/ - -// Rename 重命名对象 -func (fs *FileSystem) Rename(ctx context.Context, dir, file []uint, new string) (err error) { - // 验证新名字 - if !fs.ValidateLegalName(ctx, new) || (len(file) > 0 && !fs.ValidateExtension(ctx, new)) { - return ErrIllegalObjectName - } - - // 如果源对象是文件 - if len(file) > 0 { - fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID) - if err != nil || len(fileObject) == 0 { - return ErrPathNotExist - } - - err = fileObject[0].Rename(new) - if err != nil { - return ErrFileExisted - } - return nil - } - - if len(dir) > 0 { - folderObject, err := model.GetFoldersByIDs([]uint{dir[0]}, fs.User.ID) - if err != nil || len(folderObject) == 0 { - return ErrPathNotExist - } - - err = folderObject[0].Rename(new) - if err != nil { - return ErrFileExisted - } - return nil - } - - return ErrPathNotExist -} - -// Copy 复制src目录下的文件或目录到dst, -// 暂时只支持单文件 -func (fs *FileSystem) Copy(ctx context.Context, dirs, files []uint, src, dst string) error { - // 获取目的目录 - isDstExist, dstFolder := fs.IsPathExist(dst) - isSrcExist, srcFolder := fs.IsPathExist(src) - // 不存在时返回空的结果 - if !isDstExist || !isSrcExist { - return ErrPathNotExist - } - - // 记录复制的文件的总容量 - var newUsedStorage uint64 - - // 设置webdav目标名 - if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok { - dstFolder.WebdavDstName = dstName - } - - // 复制目录 - if len(dirs) > 0 { - subFileSizes, err := srcFolder.CopyFolderTo(dirs[0], dstFolder) - if err != nil { - 
return ErrObjectNotExist.WithError(err) - } - newUsedStorage += subFileSizes - } - - // 复制文件 - if len(files) > 0 { - subFileSizes, err := srcFolder.MoveOrCopyFileTo(files, dstFolder, true) - if err != nil { - return ErrObjectNotExist.WithError(err) - } - newUsedStorage += subFileSizes - } - - // 扣除容量 - fs.User.IncreaseStorageWithoutCheck(newUsedStorage) - - return nil -} - -// Move 移动文件和目录, 将id列表dirs和files从src移动至dst -func (fs *FileSystem) Move(ctx context.Context, dirs, files []uint, src, dst string) error { - // 获取目的目录 - isDstExist, dstFolder := fs.IsPathExist(dst) - isSrcExist, srcFolder := fs.IsPathExist(src) - // 不存在时返回空的结果 - if !isDstExist || !isSrcExist { - return ErrPathNotExist - } - - // 设置webdav目标名 - if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok { - dstFolder.WebdavDstName = dstName - } - - // 处理目录及子文件移动 - err := srcFolder.MoveFolderTo(dirs, dstFolder) - if err != nil { - return ErrFileExisted.WithError(err) - } - - // 处理文件移动 - _, err = srcFolder.MoveOrCopyFileTo(files, dstFolder, false) - if err != nil { - return ErrFileExisted.WithError(err) - } - - // 移动文件 - - return err -} - -// Delete 递归删除对象, force 为 true 时强制删除文件记录,忽略物理删除是否成功; -// unlink 为 true 时只删除虚拟文件系统的文件记录,不删除物理文件。 -func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force, unlink bool) error { - // 已删除的文件ID - var deletedFiles = make([]*model.File, 0, len(fs.FileTarget)) - // 删除失败的文件的父目录ID - - // 所有文件的ID - var allFiles = make([]*model.File, 0, len(fs.FileTarget)) - - // 列出要删除的目录 - if len(dirs) > 0 { - err := fs.ListDeleteDirs(ctx, dirs) - if err != nil { - return err - } - } - - // 列出要删除的文件 - if len(files) > 0 { - err := fs.ListDeleteFiles(ctx, files) - if err != nil { - return err - } - } - - // 去除待删除文件中包含软连接的部分 - filesToBeDelete, err := model.RemoveFilesWithSoftLinks(fs.FileTarget) - if err != nil { - return ErrDBListObjects.WithError(err) - } - - // 根据存储策略将文件分组 - policyGroup := fs.GroupFileByPolicy(ctx, filesToBeDelete) - - // 按照存储策略分组删除对象 - failed := make(map[uint][]string) - if !unlink { - failed = fs.deleteGroupedFile(ctx, policyGroup) - } - - // 整理删除结果 - for i := 0; i < len(fs.FileTarget); i++ { - if !util.ContainsString(failed[fs.FileTarget[i].PolicyID], fs.FileTarget[i].SourceName) { - // 已成功删除的文件 - deletedFiles = append(deletedFiles, &fs.FileTarget[i]) - } - - // 全部文件 - allFiles = append(allFiles, &fs.FileTarget[i]) - } - - // 如果强制删除,则将全部文件视为删除成功 - if force { - deletedFiles = allFiles - } - - // 删除文件记录 - err = model.DeleteFiles(deletedFiles, fs.User.ID) - if err != nil { - return ErrDBDeleteObjects.WithError(err) - } - - // 删除文件记录对应的分享记录 - // TODO 先取消分享再删除文件 - deletedFileIDs := make([]uint, len(deletedFiles)) - for k, file := range deletedFiles { - deletedFileIDs[k] = file.ID - } - - model.DeleteShareBySourceIDs(deletedFileIDs, false) - - // 如果文件全部删除成功,继续删除目录 - if len(deletedFiles) == len(allFiles) { - var allFolderIDs = make([]uint, 0, len(fs.DirTarget)) - for _, value := range fs.DirTarget { - allFolderIDs = append(allFolderIDs, value.ID) - } - err = model.DeleteFolderByIDs(allFolderIDs) - if err != nil { - return ErrDBDeleteObjects.WithError(err) - } - - // 删除目录记录对应的分享记录 - model.DeleteShareBySourceIDs(allFolderIDs, true) - } - - if notDeleted := len(fs.FileTarget) - len(deletedFiles); notDeleted > 0 { - return serializer.NewError( - serializer.CodeNotFullySuccess, - fmt.Sprintf("Failed to delete %d file(s).", notDeleted), - nil, - ) - } - - return nil -} - -// ListDeleteDirs 递归列出要删除目录,及目录下所有文件 -func (fs *FileSystem) ListDeleteDirs(ctx context.Context, ids []uint) error { - 
// 列出所有递归子目录 - folders, err := model.GetRecursiveChildFolder(ids, fs.User.ID, true) - if err != nil { - return ErrDBListObjects.WithError(err) - } - - // 忽略根目录 - for i := 0; i < len(folders); i++ { - if folders[i].ParentID == nil { - folders = append(folders[:i], folders[i+1:]...) - break - } - } - - fs.SetTargetDir(&folders) - - // 检索目录下的子文件 - files, err := model.GetChildFilesOfFolders(&folders) - if err != nil { - return ErrDBListObjects.WithError(err) - } - fs.SetTargetFile(&files) - - return nil -} - -// ListDeleteFiles 根据给定的路径列出要删除的文件 -func (fs *FileSystem) ListDeleteFiles(ctx context.Context, ids []uint) error { - files, err := model.GetFilesByIDs(ids, fs.User.ID) - if err != nil { - return ErrDBListObjects.WithError(err) - } - fs.SetTargetFile(&files) - return nil -} - -// List 列出路径下的内容, -// pathProcessor为最终对象路径的处理钩子。 -// 有些情况下(如在分享页面列对象)时, -// 路径需要截取掉被分享目录路径之前的部分。 -func (fs *FileSystem) List(ctx context.Context, dirPath string, pathProcessor func(string) string) ([]serializer.Object, error) { - // 获取父目录 - isExist, folder := fs.IsPathExist(dirPath) - if !isExist { - return nil, ErrPathNotExist - } - fs.SetTargetDir(&[]model.Folder{*folder}) - - var parentPath = path.Join(folder.Position, folder.Name) - var childFolders []model.Folder - var childFiles []model.File - - // 获取子目录 - childFolders, _ = folder.GetChildFolder() - - // 获取子文件 - childFiles, _ = folder.GetChildFiles() - - return fs.listObjects(ctx, parentPath, childFiles, childFolders, pathProcessor), nil -} - -// ListPhysical 列出存储策略中的外部目录 -// TODO:测试 -func (fs *FileSystem) ListPhysical(ctx context.Context, dirPath string) ([]serializer.Object, error) { - if err := fs.DispatchHandler(); fs.Policy == nil || err != nil { - return nil, ErrUnknownPolicyType - } - - // 存储策略不支持列取时,返回空结果 - if !fs.Policy.CanStructureBeListed() { - return nil, nil - } - - // 列取路径 - objects, err := fs.Handler.List(ctx, dirPath, false) - if err != nil { - return nil, err - } - - var ( - folders []model.Folder - ) - for _, object := range objects { - if object.IsDir { - folders = append(folders, model.Folder{ - Name: object.Name, - }) - } - } - - return fs.listObjects(ctx, dirPath, nil, folders, nil), nil -} - -func (fs *FileSystem) listObjects(ctx context.Context, parent string, files []model.File, folders []model.Folder, pathProcessor func(string) string) []serializer.Object { - // 分享文件的ID - shareKey := "" - if key, ok := ctx.Value(fsctx.ShareKeyCtx).(string); ok { - shareKey = key - } - - // 汇总处理结果 - objects := make([]serializer.Object, 0, len(files)+len(folders)) - - // 所有对象的父目录 - var processedPath string - - for _, subFolder := range folders { - // 路径处理钩子, - // 所有对象父目录都是一样的,所以只处理一次 - if processedPath == "" { - if pathProcessor != nil { - processedPath = pathProcessor(parent) - } else { - processedPath = parent - } - } - - objects = append(objects, serializer.Object{ - ID: hashid.HashID(subFolder.ID, hashid.FolderID), - Name: subFolder.Name, - Path: processedPath, - Size: 0, - Type: "dir", - Date: subFolder.UpdatedAt, - CreateDate: subFolder.CreatedAt, - }) - } - - for _, file := range files { - if processedPath == "" { - if pathProcessor != nil { - processedPath = pathProcessor(parent) - } else { - processedPath = parent - } - } - - if file.UploadSessionID == nil { - newFile := serializer.Object{ - ID: hashid.HashID(file.ID, hashid.FileID), - Name: file.Name, - Path: processedPath, - Thumb: file.ShouldLoadThumb(), - Size: file.Size, - Type: "file", - Date: file.UpdatedAt, - SourceEnabled: file.GetPolicy().IsOriginLinkEnable, - CreateDate: 
file.CreatedAt, - } - if shareKey != "" { - newFile.Key = shareKey - } - objects = append(objects, newFile) - } - } - - return objects -} - -// CreateDirectory 根据给定的完整创建目录,支持递归创建。如果目录已存在,则直接 -// 返回已存在的目录。 -func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*model.Folder, error) { - if fullPath == "." || fullPath == "" { - return nil, ErrRootProtected - } - - if fullPath == "/" { - if fs.Root != nil { - return fs.Root, nil - } - return fs.User.Root() - } - - // 获取要创建目录的父路径和目录名 - fullPath = path.Clean(fullPath) - base := path.Dir(fullPath) - dir := path.Base(fullPath) - - // 去掉结尾空格 - dir = strings.TrimRight(dir, " ") - - // 检查目录名是否合法 - if !fs.ValidateLegalName(ctx, dir) { - return nil, ErrIllegalObjectName - } - - // 父目录是否存在 - isExist, parent := fs.IsPathExist(base) - if !isExist { - newParent, err := fs.CreateDirectory(ctx, base) - if err != nil { - return nil, err - } - parent = newParent - } - - // 是否有同名文件 - if ok, _ := fs.IsChildFileExist(parent, dir); ok { - return nil, ErrFileExisted - } - - // 创建目录 - newFolder := model.Folder{ - Name: dir, - ParentID: &parent.ID, - OwnerID: fs.User.ID, - } - _, err := newFolder.Create() - - if err != nil { - return nil, fmt.Errorf("failed to create folder: %w", err) - } - - return &newFolder, nil -} - -// SaveTo 将别人分享的文件转存到目标路径下 -func (fs *FileSystem) SaveTo(ctx context.Context, path string) error { - // 获取父目录 - isExist, folder := fs.IsPathExist(path) - if !isExist { - return ErrPathNotExist - } - - var ( - totalSize uint64 - err error - ) - - if len(fs.DirTarget) > 0 { - totalSize, err = fs.DirTarget[0].CopyFolderTo(fs.DirTarget[0].ID, folder) - } else { - parent := model.Folder{ - OwnerID: fs.FileTarget[0].UserID, - } - parent.ID = fs.FileTarget[0].FolderID - totalSize, err = parent.MoveOrCopyFileTo([]uint{fs.FileTarget[0].ID}, folder, true) - } - - // 扣除用户容量 - fs.User.IncreaseStorageWithoutCheck(totalSize) - if err != nil { - return ErrFileExisted.WithError(err) - } - - return nil -} diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go deleted file mode 100644 index 1f2cc1ae..00000000 --- a/pkg/filesystem/manage_test.go +++ /dev/null @@ -1,848 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "os" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - testMock "github.com/stretchr/testify/mock" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_ListPhysical(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - // 未知存储策略 - { - fs.Policy.Type = "unknown" - res, err := fs.ListPhysical(ctx, "/") - asserts.Equal(ErrUnknownPolicyType, err) - asserts.Empty(res) - fs.Policy.Type = "mock" - } - - // 无法列取目录 - { - testHandler := new(FileHeaderMock) - testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return([]response.Object{}, errors.New("error")) - fs.Handler = testHandler - res, err := fs.ListPhysical(ctx, "/") - asserts.EqualError(err, "error") - asserts.Empty(res) - } - - // 成功 - { - testHandler := new(FileHeaderMock) - 
testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return( - []response.Object{{IsDir: true, Name: "1"}, {IsDir: false, Name: "2"}}, - nil, - ) - fs.Handler = testHandler - res, err := fs.ListPhysical(ctx, "/") - asserts.NoError(err) - asserts.Len(res, 1) - asserts.Equal("1", res[0].Name) - } -} - -func TestFileSystem_List(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 成功,子目录包含文件和路径,不使用路径处理钩子 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt")) - objects, err := fs.List(ctx, "/folder", nil) - asserts.Len(objects, 4) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录包含文件和路径,不使用路径处理钩子,包含分享key - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt")) - ctxWithKey := context.WithValue(ctx, fsctx.ShareKeyCtx, "share") - objects, err = fs.List(ctxWithKey, "/folder", nil) - asserts.Len(objects, 4) - asserts.Equal("share", objects[3].Key) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录包含文件和路径,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}).AddRow(6, "sub_folder1", "/folder").AddRow(7, "sub_folder2", "/folder")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 4) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - for _, value := range objects { - asserts.Contains(value.Path, "prefix/") - } - - // 成功,子目录包含路径,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 2) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - for _, value := range objects { - asserts.Contains(value.Path, "prefix/") - } - - // 成功,子目录下为空,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"})) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 0) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录路径不存在 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"})) - - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 0) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_CreateDirectory(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 目录名非法 - _, err := fs.CreateDirectory(ctx, "/ad/a+?") - asserts.Equal(ErrIllegalObjectName, err) - - // 存在同名文件 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab")) - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.Equal(ErrFileExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 存在同名目录,直接返回 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // ab - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - res, err := fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.EqualValues(3, res.ID) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功创建 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功创建, 递归创建父目录 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ad", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ab - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 底层创建失败 - // 成功创建 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ad", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)).WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 直接创建根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - _, err = fs.CreateDirectory(ctx, "/") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 直接创建根目录, 重设根目录 - fs.Root = &model.Folder{} - _, err = fs.CreateDirectory(ctx, "/") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_ListDeleteFiles(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt").AddRow(2, "2.txt")) - err := fs.ListDeleteFiles(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - err := fs.ListDeleteFiles(context.Background(), []uint{1}) - asserts.Error(err) - asserts.Equal(serializer.CodeDBError, err.(serializer.AppError).Code) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_ListDeleteDirs(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 成功 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "1.txt"). - AddRow(5, "2.txt"). - AddRow(6, "3.txt"), - ) - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.Len(fs.FileTarget, 3) - asserts.Len(fs.DirTarget, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功,忽略根目录 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, nil). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "1.txt"). - AddRow(5, "2.txt"). - AddRow(6, "3.txt"), - ) - fs.CleanTargets() - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.Len(fs.FileTarget, 3) - asserts.Len(fs.DirTarget, 2) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 检索文件发生错误 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnError(errors.New("error")) - fs.CleanTargets() - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.Error(err) - asserts.Len(fs.DirTarget, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } - // 检索目录发生错误 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_Delete(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 0, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - //全部未成功,强制 - { - fs.CleanTargets() - mock.ExpectQuery("SELECT(.+)"). 
- WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}). - AddRow(4, "1.txt", "1.txt", 365, 1), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2)) - // 两次查询软连接 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(365, "local")) - // 删除文件记录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除目录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - - fs.FileTarget = []model.File{} - fs.DirTarget = []model.Folder{} - err := fs.Delete(ctx, []uint{1}, []uint{1}, true, false) - asserts.NoError(err) - } - //全部成功 - { - fs.CleanTargets() - file, err := os.Create(util.RelativePath("1.txt")) - file2, err := os.Create(util.RelativePath("2.txt")) - file.Close() - file2.Close() - asserts.NoError(err) - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}). - AddRow(4, "1.txt", "1.txt", 602, 1), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2)) - // 两次查询软连接 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(602, "local")) - // 删除文件记录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除目录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). 
- WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - - fs.FileTarget = []model.File{} - fs.DirTarget = []model.Folder{} - err = fs.Delete(ctx, []uint{1}, []uint{1}, false, false) - asserts.NoError(err) - } - -} - -func TestFileSystem_Copy(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - err := fs.Copy(ctx, []uint{}, []uint{}, "/src", "/dst") - asserts.Equal(ErrPathNotExist, err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 复制目录出错 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "dst"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "src"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - err := fs.Copy(ctx, []uint{1}, []uint{}, "/src", "/dst") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - -} - -func TestFileSystem_Move(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - err := fs.Move(ctx, []uint{}, []uint{}, "/src", "/dst") - asserts.Equal(ErrPathNotExist, err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 移动目录出错 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "dst"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "src"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - err := fs.Move(ctx, []uint{1}, []uint{}, "/src", "/dst") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_Rename(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{}, - } - ctx := context.Background() - - // 重命名文件 成功 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)"). - WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 重命名文件 不存在 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). 
- WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 重命名文件 失败 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)"). - WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrFileExisted, err) - } - - // 重命名目录 成功 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("new", 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 重命名目录 不存在 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 重命名目录 失败 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("new", 10). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrFileExisted, err) - } - - // 未选中任何对象 - { - err := fs.Rename(ctx, []uint{}, []uint{}, "new") - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 新名字是目录,不合法 - { - err := fs.Rename(ctx, []uint{10}, []uint{}, "ne/w") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是文件,不合法 - { - err := fs.Rename(ctx, []uint{}, []uint{10}, "ne/w") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是文件,扩展名不合法 - { - fs.Policy.OptionsSerialized.FileType = []string{"txt"} - err := fs.Rename(ctx, []uint{}, []uint{10}, "1.jpg") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是目录,不应该检测扩展名 - { - fs.Policy.OptionsSerialized.FileType = []string{"txt"} - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } -} - -func TestFileSystem_SaveTo(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 单文件 失败 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - fs.SetTargetFile(&[]model.File{{Name: "test.txt"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 目录 成功 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - fs.SetTargetDir(&[]model.Folder{{Name: "folder"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 父目录不存在 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - fs.SetTargetDir(&[]model.Folder{{Name: "folder"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/oauth/mutex.go b/pkg/filesystem/oauth/mutex.go deleted file mode 100644 index 41f588d0..00000000 --- a/pkg/filesystem/oauth/mutex.go +++ /dev/null @@ -1,25 +0,0 @@ -package oauth - -import "sync" - -// CredentialLock 针对存储策略凭证的锁 -type CredentialLock interface { - Lock(uint) - Unlock(uint) -} - -var GlobalMutex = mutexMap{} - -type mutexMap struct { - locks sync.Map -} - -func (m *mutexMap) Lock(id uint) { - lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) - lock.(*sync.Mutex).Lock() -} - -func (m *mutexMap) Unlock(id uint) { - lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) - lock.(*sync.Mutex).Unlock() -} diff --git a/pkg/filesystem/oauth/token.go b/pkg/filesystem/oauth/token.go deleted file mode 100644 index cdc5cf05..00000000 --- a/pkg/filesystem/oauth/token.go +++ /dev/null @@ -1,8 +0,0 @@ -package oauth - -import "context" - -type TokenProvider interface { - UpdateCredential(ctx context.Context, isSlave bool) error - AccessToken() string -} diff --git a/pkg/filesystem/path.go b/pkg/filesystem/path.go deleted file mode 100644 index b0637aa7..00000000 --- a/pkg/filesystem/path.go +++ /dev/null @@ -1,75 +0,0 @@ -package filesystem - -import ( - "path" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -/* ================= - 路径/目录相关 - ================= -*/ - -// IsPathExist 返回给定目录是否存在 -// 如果存在就返回目录 -func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) { - pathList := util.SplitPath(path) - if len(pathList) == 0 { - return false, nil - } - - // 递归步入目录 - // TODO:测试新增 - var currentFolder *model.Folder - - // 如果已设定跟目录对象,则从给定目录向下遍历 - if fs.Root != nil { - currentFolder = fs.Root - } - - for _, folderName := range pathList { - var err error - - // 根目录 - if folderName == "/" { - if currentFolder != nil { - continue - } - currentFolder, err = fs.User.Root() - if err != nil { - return false, nil - } - } else { - currentFolder, err = currentFolder.GetChild(folderName) - if err != nil { - return false, nil - } - } - } - - return true, currentFolder -} - -// IsFileExist 返回给定路径的文件是否存在 -func (fs *FileSystem) IsFileExist(fullPath string) (bool, *model.File) { - basePath := path.Dir(fullPath) - fileName := path.Base(fullPath) - - // 获得父目录 - exist, parent := fs.IsPathExist(basePath) - if !exist { - return false, nil - } - - file, err := parent.GetChildFile(fileName) - - return err == nil, file -} - -// IsChildFileExist 确定folder目录下是否有名为name的文件 -func (fs *FileSystem) IsChildFileExist(folder *model.Folder, name string) (bool, *model.File) { - file, err := 
folder.GetChildFile(name) - return err == nil, file -} diff --git a/pkg/filesystem/path_test.go b/pkg/filesystem/path_test.go deleted file mode 100644 index e4065a4f..00000000 --- a/pkg/filesystem/path_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package filesystem - -import ( - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_IsFileExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - exist, file := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(1), file.ID) - } - - // 文件不存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - exist, _ := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - } - - // 父目录不存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - exist, _ := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - } -} - -func TestFileSystem_IsPathExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 查询根目录 - { - path := "/" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(1), folder.ID) - } - - // 深层路径 - { - path := "/1/2/3" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "1"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(4), folder.ID) - } - - // 深层路径 重设根目录为/1 - { - path := "/2/3" - fs.Root = &model.Folder{Name: "1", Model: gorm.Model{ID: 2}, OwnerID: 1} - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(4), folder.ID) - fs.Root = nil - } - - // 深层 不存在 - { - path := "/1/2/3" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "1"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - asserts.Nil(folder) - } - -} - -func TestFileSystem_IsChildFileExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - folder := model.Folder{ - Model: gorm.Model{ID: 1}, - Name: "123", - Position: "/", - } - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "321"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "321")) - exist, childFile := fs.IsChildFileExist(&folder, "321") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal("/123", childFile.Position) -} diff --git a/pkg/filesystem/response/common.go b/pkg/filesystem/response/common.go deleted file mode 100644 index a6c9a1da..00000000 --- a/pkg/filesystem/response/common.go +++ /dev/null @@ -1,32 +0,0 @@ -package response - -import ( - "io" - "time" -) - -// ContentResponse 获取文件内容类方法的通用返回值。 -// 有些上传策略需要重定向, -// 有些直接写文件数据到浏览器 -type ContentResponse struct { - Redirect bool - Content RSCloser - URL string - MaxAge int -} - -// RSCloser 存储策略适配器返回的文件流,有些策略需要带有Closer -type RSCloser interface { - io.ReadSeeker - io.Closer -} - -// Object 列出文件、目录时返回的对象 -type Object struct { - Name string `json:"name"` - RelativePath string `json:"relative_path"` - Source string `json:"source"` - Size uint64 `json:"size"` - IsDir bool `json:"is_dir"` - LastModify time.Time `json:"last_modify"` -} diff --git a/pkg/filesystem/tests/file1.txt b/pkg/filesystem/tests/file1.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/filesystem/tests/file2.txt b/pkg/filesystem/tests/file2.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/filesystem/tests/test.zip b/pkg/filesystem/tests/test.zip deleted file mode 100644 index 316212ee..00000000 Binary files a/pkg/filesystem/tests/test.zip and /dev/null differ diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go deleted file mode 100644 index 08dde53d..00000000 --- a/pkg/filesystem/upload.go +++ /dev/null @@ -1,245 +0,0 @@ -package filesystem - -import ( - "context" - "os" - "path" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/gofrs/uuid" -) - -/* ================ - 上传处理相关 - ================ -*/ - -const ( - UploadSessionMetaKey = "upload_session" - UploadSessionCtx = "uploadSession" - UserCtx = "user" - UploadSessionCachePrefix = "callback_" -) - -// Upload 上传文件 -func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err error) { - // 上传前的钩子 - err = fs.Trigger(ctx, "BeforeUpload", file) - if err != nil { - request.BlackHole(file) - return err - } - - // 生成文件名和路径, - var savePath string - if file.SavePath == "" { - // 如果是更新操作就从上下文中获取 - if 
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { - savePath = originFile.SourceName - } else { - savePath = fs.GenerateSavePath(ctx, file) - } - file.SavePath = savePath - } - - // 保存文件 - if file.Mode&fsctx.Nop != fsctx.Nop { - // 处理客户端未完成上传时,关闭连接 - go fs.CancelUpload(ctx, savePath, file) - - err = fs.Handler.Put(ctx, file) - if err != nil { - fs.Trigger(ctx, "AfterUploadFailed", file) - return err - } - } - - // 上传完成后的钩子 - err = fs.Trigger(ctx, "AfterUpload", file) - - if err != nil { - // 上传完成后续处理失败 - followUpErr := fs.Trigger(ctx, "AfterValidateFailed", file) - // 失败后再失败... - if followUpErr != nil { - util.Log().Debug("AfterValidateFailed hook execution failed: %s", followUpErr) - } - - return err - } - - return nil -} - -// GenerateSavePath 生成要存放文件的路径 -// TODO 完善测试 -func (fs *FileSystem) GenerateSavePath(ctx context.Context, file fsctx.FileHeader) string { - fileInfo := file.Info() - return path.Join( - fs.Policy.GeneratePath( - fs.User.Model.ID, - fileInfo.VirtualPath, - ), - fs.Policy.GenerateFileName( - fs.User.Model.ID, - fileInfo.FileName, - ), - ) - -} - -// CancelUpload 监测客户端取消上传 -func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file fsctx.FileHeader) { - var reqContext context.Context - if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok { - reqContext = ginCtx.Request.Context() - } else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok { - reqContext = reqCtx - } else { - return - } - - select { - case <-reqContext.Done(): - select { - case <-ctx.Done(): - // 客户端正常关闭,不执行操作 - default: - // 客户端取消上传,删除临时文件 - util.Log().Debug("Client canceled upload.") - if fs.Hooks["AfterUploadCanceled"] == nil { - return - } - err := fs.Trigger(ctx, "AfterUploadCanceled", file) - if err != nil { - util.Log().Debug("AfterUploadCanceled hook execution failed: %s", err) - } - } - - } -} - -// CreateUploadSession 创建上传会话 -func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileStream) (*serializer.UploadCredential, error) { - // 获取相关有效期设置 - callBackSessionTTL := model.GetIntSetting("upload_session_timeout", 86400) - - callbackKey := uuid.Must(uuid.NewV4()).String() - fileSize := file.Size - - // 创建占位的文件,同时校验文件信息 - file.Mode = fsctx.Nop - if callbackKey != "" { - file.UploadSessionID = &callbackKey - } - - fs.Use("BeforeUpload", HookValidateFile) - fs.Use("BeforeUpload", HookValidateCapacity) - - // 验证文件规格 - if err := fs.Upload(ctx, file); err != nil { - return nil, err - } - - uploadSession := &serializer.UploadSession{ - Key: callbackKey, - UID: fs.User.ID, - Policy: *fs.Policy, - VirtualPath: file.VirtualPath, - Name: file.Name, - Size: fileSize, - SavePath: file.SavePath, - LastModified: file.LastModified, - CallbackSecret: util.RandStringRunes(32), - } - - // 获取上传凭证 - credential, err := fs.Handler.Token(ctx, int64(callBackSessionTTL), uploadSession, file) - if err != nil { - return nil, err - } - - // 创建占位符 - if !fs.Policy.IsUploadPlaceholderWithSize() { - fs.Use("AfterUpload", HookClearFileHeaderSize) - } - fs.Use("AfterUpload", GenericAfterUpload) - ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true) - if err := fs.Upload(ctx, file); err != nil { - return nil, err - } - - // 创建回调会话 - err = cache.Set( - UploadSessionCachePrefix+callbackKey, - *uploadSession, - callBackSessionTTL, - ) - if err != nil { - return nil, err - } - - // 补全上传凭证其他信息 - credential.Expires = time.Now().Add(time.Duration(callBackSessionTTL) * time.Second).Unix() - - return credential, nil -} - -// UploadFromStream 从文件流上传文件 
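For context on the v3 flow being deleted here: Upload wraps the storage handler's Put in an ordered hook chain, and callers attach validation and cleanup steps through fs.Use, exactly as the UploadFromStream implementation below does when no hooks are registered yet. A minimal sketch of that (now removed) wiring — the file name, path, and size are illustrative values, not anything from this change:

```go
package example

import (
	"context"
	"io"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
)

// uploadWithHooks mirrors the hook wiring the deleted v3 code set up before Upload.
func uploadWithHooks(ctx context.Context, fs *filesystem.FileSystem, src io.ReadCloser, size uint64) error {
	fs.Use("BeforeUpload", filesystem.HookValidateFile)           // name/size validation
	fs.Use("BeforeUpload", filesystem.HookValidateCapacity)       // reserve user quota
	fs.Use("AfterUpload", filesystem.GenericAfterUpload)          // persist the file record
	fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)  // roll back the stored blob

	// Any hook or handler error aborts the chain and is returned here.
	return fs.Upload(ctx, &fsctx.FileStream{
		File:        src,
		Size:        size,
		Name:        "report.pdf",
		VirtualPath: "/documents",
	})
}
```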
-func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error { - if resetPolicy { - // 重设存储策略 - fs.Policy = &fs.User.Policy - err := fs.DispatchHandler() - if err != nil { - return err - } - } - - // 给文件系统分配钩子 - fs.Lock.Lock() - if fs.Hooks == nil { - fs.Use("BeforeUpload", HookValidateFile) - fs.Use("BeforeUpload", HookValidateCapacity) - fs.Use("AfterUploadCanceled", HookDeleteTempFile) - fs.Use("AfterUpload", GenericAfterUpload) - fs.Use("AfterValidateFailed", HookDeleteTempFile) - } - fs.Lock.Unlock() - - // 开始上传 - return fs.Upload(ctx, file) -} - -// UploadFromPath 将本机已有文件上传到用户的文件系统 -func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error { - file, err := os.Open(util.RelativePath(src)) - if err != nil { - return err - } - defer file.Close() - - // 获取源文件大小 - fi, err := file.Stat() - if err != nil { - return err - } - size := fi.Size() - - // 开始上传 - return fs.UploadFromStream(ctx, &fsctx.FileStream{ - File: file, - Seeker: file, - Size: uint64(size), - Name: path.Base(dst), - VirtualPath: path.Dir(dst), - Mode: mode, - }, true) -} diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go deleted file mode 100644 index 61dad9f8..00000000 --- a/pkg/filesystem/upload_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -type FileHeaderMock struct { - testMock.Mock -} - -func (m FileHeaderMock) Put(ctx context.Context, file fsctx.FileHeader) error { - args := m.Called(ctx, file) - return args.Error(0) -} - -func (m FileHeaderMock) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - args := m.Called(ctx, ttl, uploadSession, file) - return args.Get(0).(*serializer.UploadCredential), args.Error(1) -} - -func (m FileHeaderMock) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - args := m.Called(ctx, uploadSession) - return args.Error(0) -} - -func (m FileHeaderMock) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - args := m.Called(ctx, path, recursive) - return args.Get(0).([]response.Object), args.Error(1) -} - -func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) { - args := m.Called(ctx, path) - return args.Get(0).(response.RSCloser), args.Error(1) -} - -func (m FileHeaderMock) Delete(ctx context.Context, files []string) ([]string, error) { - args := m.Called(ctx, files) - return args.Get(0).([]string), args.Error(1) -} - -func (m FileHeaderMock) Thumb(ctx context.Context, files *model.File) (*response.ContentResponse, error) { - args := m.Called(ctx, files) - return args.Get(0).(*response.ContentResponse), args.Error(1) -} - -func (m FileHeaderMock) Source(ctx context.Context, path string, expires int64, isDownload bool, speed int) (string, error) { - args := m.Called(ctx, path, expires, isDownload, speed) - return 
args.Get(0).(string), args.Error(1) -} - -func TestFileSystem_Upload(t *testing.T) { - asserts := assert.New(t) - - // 正常 - testHandler := new(FileHeaderMock) - testHandler.On("Put", testMock.Anything, testMock.Anything, testMock.Anything).Return(nil) - fs := &FileSystem{ - Handler: testHandler, - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{ - AutoRename: false, - DirNameRule: "{path}", - }, - } - ctx, cancel := context.WithCancel(context.Background()) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - cancel() - file := &fsctx.FileStream{ - Size: 5, - VirtualPath: "/", - Name: "1.txt", - } - err := fs.Upload(ctx, file) - asserts.NoError(err) - - // 正常,上下文已指定源文件 - testHandler = new(FileHeaderMock) - testHandler.On("Put", testMock.Anything, testMock.Anything).Return(nil) - fs = &FileSystem{ - Handler: testHandler, - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{ - AutoRename: false, - DirNameRule: "{path}", - }, - } - ctx, cancel = context.WithCancel(context.Background()) - c, _ = gin.CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{SourceName: "123/123.txt"}) - cancel() - file = &fsctx.FileStream{ - Size: 5, - VirtualPath: "/", - Name: "1.txt", - File: ioutil.NopCloser(strings.NewReader("")), - } - err = fs.Upload(ctx, file) - asserts.NoError(err) - - // BeforeUpload 返回错误 - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - err = fs.Upload(ctx, file) - asserts.Error(err) - fs.Hooks["BeforeUpload"] = nil - testHandler.AssertExpectations(t) - - // 上传文件失败 - testHandler2 := new(FileHeaderMock) - testHandler2.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error")) - fs.Handler = testHandler2 - err = fs.Upload(ctx, file) - asserts.Error(err) - testHandler2.AssertExpectations(t) - - // AfterUpload失败 - testHandler3 := new(FileHeaderMock) - testHandler3.On("Put", testMock.Anything, testMock.Anything).Return(nil) - fs.Handler = testHandler3 - fs.Use("AfterUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - err = fs.Upload(ctx, file) - asserts.Error(err) - testHandler2.AssertExpectations(t) - -} - -func TestFileSystem_GetUploadToken(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - Policy: &model.Policy{}, - } - ctx := context.Background() - - // 成功 - { - cache.SetSettings(map[string]string{ - "upload_session_timeout": "10", - }, "setting_") - testHandler := new(FileHeaderMock) - testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{Credential: "test"}, nil) - fs.Handler = testHandler - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - res, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{ - Size: 0, - Name: "file", - VirtualPath: "/", - }) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.NoError(err) - asserts.Equal("test", res.Credential) - } - - // 无法获取上传凭证 - { - cache.SetSettings(map[string]string{ - "upload_credential_timeout": "10", - "upload_session_timeout": "10", - }, "setting_") - testHandler := new(FileHeaderMock) - testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{}, errors.New("error")) - fs.Handler = testHandler - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{ - Size: 0, - Name: "file", - VirtualPath: "/", - }) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.Error(err) - } -} - -func TestFileSystem_UploadFromStream(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - Policy: model.Policy{Type: "mock"}, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - err := fs.UploadFromStream(ctx, &fsctx.FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - }, true) - asserts.Error(err) -} - -func TestFileSystem_UploadFromPath(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - Policy: model.Policy{Type: "mock"}, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - // 文件不存在 - { - err := fs.UploadFromPath(ctx, "test/not_exist", "/", fsctx.Overwrite) - asserts.Error(err) - } - - // 文存在,上传失败 - { - err := fs.UploadFromPath(ctx, "tests/test.zip", "/", fsctx.Overwrite) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/validator.go b/pkg/filesystem/validator.go deleted file mode 100644 index 1992547e..00000000 --- a/pkg/filesystem/validator.go +++ /dev/null @@ -1,66 +0,0 @@ -package filesystem - -import ( - "context" - "strings" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -/* ========== - 验证器 - ========== -*/ - -// 文件/路径名保留字符 -var reservedCharacter = []string{"\\", "?", "*", "<", "\"", ":", ">", "/", "|"} - -// ValidateLegalName 验证文件名/文件夹名是否合法 -func (fs *FileSystem) ValidateLegalName(ctx context.Context, name string) bool { - // 是否包含保留字符 - for _, value := range reservedCharacter { - if strings.Contains(name, value) { - return false - } - } - - // 是否超出长度限制 - if len(name) >= 256 { - return false - } - - // 是否为空限制 - if len(name) == 0 { - return false - } - - // 结尾不能是空格 - if strings.HasSuffix(name, " ") { - return false - } - - return true -} - -// ValidateFileSize 验证上传的文件大小是否超出限制 -func (fs *FileSystem) ValidateFileSize(ctx context.Context, size uint64) bool { - if 
fs.Policy.MaxSize == 0 { - return true - } - return size <= fs.Policy.MaxSize -} - -// ValidateCapacity 验证并扣除用户容量 -func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool { - return fs.User.IncreaseStorage(size) -} - -// ValidateExtension 验证文件扩展名 -func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bool { - // 不需要验证 - if len(fs.Policy.OptionsSerialized.FileType) == 0 { - return true - } - - return util.IsInExtensionList(fs.Policy.OptionsSerialized.FileType, fileName) -} diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go deleted file mode 100644 index 8f685f27..00000000 --- a/pkg/filesystem/validator_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package filesystem - -import ( - "context" - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestFileSystem_ValidateLegalName(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{} - asserts.True(fs.ValidateLegalName(ctx, "1.txt")) - asserts.True(fs.ValidateLegalName(ctx, "1-1.txt")) - asserts.True(fs.ValidateLegalName(ctx, "1?1.txt")) - asserts.False(fs.ValidateLegalName(ctx, "1:1.txt")) - asserts.False(fs.ValidateLegalName(ctx, "../11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "/11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "\\11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "")) - asserts.False(fs.ValidateLegalName(ctx, "1.tx t ")) - asserts.True(fs.ValidateLegalName(ctx, "1.tx t")) -} - -func TestFileSystem_ValidateCapacity(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - cache.Set("pack_size_0", uint64(0), 0) - fs := FileSystem{ - User: &model.User{ - Storage: 10, - Group: model.Group{ - MaxStorage: 11, - }, - }, - } - - asserts.True(fs.ValidateCapacity(ctx, 1)) - asserts.Equal(uint64(11), fs.User.Storage) - - fs.User.Storage = 5 - asserts.False(fs.ValidateCapacity(ctx, 10)) - asserts.Equal(uint64(5), fs.User.Storage) -} - -func TestFileSystem_ValidateFileSize(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - MaxSize: 10, - }, - } - - asserts.True(fs.ValidateFileSize(ctx, 5)) - asserts.True(fs.ValidateFileSize(ctx, 10)) - asserts.False(fs.ValidateFileSize(ctx, 11)) - - // 无限制 - fs.Policy.MaxSize = 0 - asserts.True(fs.ValidateFileSize(ctx, 11)) -} - -func TestFileSystem_ValidateExtension(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - OptionsSerialized: model.PolicyOption{ - FileType: nil, - }, - }, - } - - asserts.True(fs.ValidateExtension(ctx, "1")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - - fs.Policy.OptionsSerialized.FileType = []string{} - asserts.True(fs.ValidateExtension(ctx, "1")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - - fs.Policy.OptionsSerialized.FileType = []string{"txt", "jpg"} - asserts.False(fs.ValidateExtension(ctx, "1")) - asserts.False(fs.ValidateExtension(ctx, 
"1.jpg.png")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - asserts.True(fs.ValidateExtension(ctx, "1.png.jpg")) - asserts.True(fs.ValidateExtension(ctx, "1.png.jpG")) - asserts.False(fs.ValidateExtension(ctx, "1.png")) -} diff --git a/pkg/hashid/hash.go b/pkg/hashid/hash.go index ffe59441..9f93c6e6 100644 --- a/pkg/hashid/hash.go +++ b/pkg/hashid/hash.go @@ -1,11 +1,10 @@ package hashid import ( + "context" "errors" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/speps/go-hashids" ) +import "github.com/speps/go-hashids" // ID类型 const ( @@ -16,6 +15,13 @@ const ( TagID // 标签ID PolicyID // 存储策略ID SourceLinkID + GroupID + EntityID + AuditLogID + NodeID + TaskID + DavAccountID + PaymentID ) var ( @@ -23,48 +29,124 @@ var ( ErrTypeNotMatch = errors.New("mismatched ID type.") ) -// HashEncode 对给定数据计算HashID -func HashEncode(v []int) (string, error) { - hd := hashids.NewData() - hd.Salt = conf.SystemConfig.HashIDSalt +type Encoder interface { + Encode(v []int) (string, error) + Decode(raw string, t int) (int, error) +} + +// ObjectIDCtx define key for decoded hash ID. +type ( + ObjectIDCtx struct{} + EncodeFunc func(encoder Encoder, uid int) string +) +type hashEncoder struct { + h *hashids.HashID +} + +func New(salt string) (Encoder, error) { + hd := hashids.NewData() + hd.Salt = salt h, err := hashids.NewWithData(hd) if err != nil { - return "", err + return nil, err } - id, err := h.Encode(v) + return &hashEncoder{h: h}, nil +} + +func (e *hashEncoder) Encode(v []int) (string, error) { + id, err := e.h.Encode(v) if err != nil { return "", err } return id, nil } -// HashDecode 对给定数据计算原始数据 -func HashDecode(raw string) ([]int, error) { - hd := hashids.NewData() - hd.Salt = conf.SystemConfig.HashIDSalt - - h, err := hashids.NewWithData(hd) +func (e *hashEncoder) Decode(raw string, t int) (int, error) { + res, err := e.h.DecodeWithError(raw) if err != nil { - return []int{}, err + return 0, err } - return h.DecodeWithError(raw) + if len(res) != 2 || res[1] != t { + return 0, ErrTypeNotMatch + } + return res[0], nil +} +// EncodeUserID encode user id to hash id +func EncodeUserID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, UserID}) + return res } -// HashID 计算数据库内主键对应的HashID -func HashID(id uint, t int) string { - v, _ := HashEncode([]int{int(id), t}) - return v +// EncodeGroupID encode group id to hash id +func EncodeGroupID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, GroupID}) + return res } -// DecodeHashID 计算HashID对应的数据库ID -func DecodeHashID(id string, t int) (uint, error) { - v, _ := HashDecode(id) - if len(v) != 2 || v[1] != t { - return 0, ErrTypeNotMatch - } - return uint(v[0]), nil +// EncodePaymentID encode payment id to hash id +func EncodePaymentID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, PaymentID}) + return res +} + +// EncodeFileID encode file id to hash id +func EncodeFileID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, FileID}) + return res +} + +// EncodeAuditLogID encode audit log id to hash id +func EncodeAuditLogID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, AuditLogID}) + return res +} + +// EncodeTaskID encode task id to hash id +func EncodeTaskID(encoder Encoder, uid int) string { + res, _ := encoder.Encode([]int{uid, TaskID}) + return res +} + +// EncodeEntityID encode entity id to hash id +func EncodeEntityID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, EntityID}) + return res +} + 
+// EncodeNodeID encode node id to hash id +func EncodeNodeID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, NodeID}) + return res +} + +// EncodeEntityID encode policy id to hash id +func EncodePolicyID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, PolicyID}) + return res +} + +// EncodeEntityID encode share id to hash id +func EncodeShareID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, ShareID}) + return res +} + +// EncodeDavAccountID encode dav account id to hash id +func EncodeDavAccountID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, DavAccountID}) + return res +} + +// EncodeSourceLinkID encode source link id to hash id +func EncodeSourceLinkID(encoder Encoder, id int) string { + res, _ := encoder.Encode([]int{id, SourceLinkID}) + return res +} + +func FromContext(c context.Context) int { + return c.Value(ObjectIDCtx{}).(int) } diff --git a/pkg/hashid/hash_test.go b/pkg/hashid/hash_test.go deleted file mode 100644 index 5471d9ee..00000000 --- a/pkg/hashid/hash_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package hashid - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestHashEncode(t *testing.T) { - asserts := assert.New(t) - - { - res, err := HashEncode([]int{1, 2, 3}) - asserts.NoError(err) - asserts.NotEmpty(res) - } - - { - res, err := HashEncode([]int{}) - asserts.Error(err) - asserts.Empty(res) - } - -} - -func TestHashID(t *testing.T) { - asserts := assert.New(t) - - { - res := HashID(1, ShareID) - asserts.NotEmpty(res) - } -} - -func TestHashDecode(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - res, _ := HashEncode([]int{1, 2, 3}) - decodeRes, err := HashDecode(res) - asserts.NoError(err) - asserts.Equal([]int{1, 2, 3}, decodeRes) - } - - // 出错 - { - decodeRes, err := HashDecode("233") - asserts.Error(err) - asserts.Len(decodeRes, 0) - } -} - -func TestDecodeHashID(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - uid, err := DecodeHashID(HashID(1, ShareID), ShareID) - asserts.NoError(err) - asserts.EqualValues(1, uid) - } - - // 类型不匹配 - { - uid, err := DecodeHashID(HashID(1, ShareID), UserID) - asserts.Error(err) - asserts.EqualValues(0, uid) - } -} diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go new file mode 100644 index 00000000..914fe5d9 --- /dev/null +++ b/pkg/logging/logger.go @@ -0,0 +1,205 @@ +package logging + +import ( + "context" + "fmt" + "github.com/fatih/color" + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "runtime" + "time" +) + +// Logger interface for logging messages. +type Logger interface { + Panic(format string, v ...any) + Error(format string, v ...any) + Warning(format string, v ...any) + Info(format string, v ...any) + Debug(format string, v ...any) + // Copy a new logger with a prefix. + CopyWithPrefix(prefix string) Logger + + // SupportColor returns if current logger support outputting colors. + SupportColor() bool +} + +// LoggerCtx defines keys for logger with correlation ID +type LoggerCtx struct{} + +// CorrelationIDCtx defines keys for correlation ID +type CorrelationIDCtx struct{} +type LogLevel string + +const ( + // LevelError 错误 + LevelError LogLevel = "error" + // LevelWarning 警告 + LevelWarning LogLevel = "warning" + // LevelInformational 提示 + LevelInformational LogLevel = "info" + // LevelDebug 除错 + LevelDebug LogLevel = "debug" +) + +// NewConsoleLogger initializes a new logging that prints logs to Stdout. 
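The Logger interface and LogLevel values above define the whole surface; the console implementation that follows turns every call below the configured level into a no-op, and CopyWithPrefix derives a child logger that tags each line. A short usage sketch — the messages and prefix are illustrative:

```go
package main

import "github.com/cloudreve/Cloudreve/v4/pkg/logging"

func main() {
	// LevelWarning keeps Panic/Error/Warning and silences Info/Debug.
	l := logging.NewConsoleLogger(logging.LevelWarning)

	l.Warning("disk usage at %d%%", 91)   // printed
	l.Info("user %d logged in", 7)        // no-op at this level
	l.Debug("cache hit for %q", "avatar") // no-op at this level

	// A derived logger shares the level filter but prefixes every message.
	worker := l.CopyWithPrefix("[worker-1]")
	worker.Error("job failed: %s", "timeout")
}
```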
+func NewConsoleLogger(level LogLevel) Logger { + logFunc := func(level string) loggingFunc { + return func(logger *consoleLogger, s string, a ...any) { + msg := fmt.Sprintf(s, a...) + logger.println(level, msg) + } + } + + logger := &consoleLogger{ + warning: logFunc("Warn"), + panic: func(logger *consoleLogger, s string, a ...any) { + msg := fmt.Sprintf(s, a...) + logger.println("Panic", msg) + panic(msg) + }, + error: logFunc("Error"), + info: logFunc("Info"), + debug: logFunc("Debug"), + } + + switch level { + case LevelError: + logger.warning = noopLoggingFunc + logger.info = noopLoggingFunc + logger.debug = noopLoggingFunc + case LevelWarning: + logger.info = noopLoggingFunc + logger.debug = noopLoggingFunc + case LevelInformational: + logger.debug = noopLoggingFunc + case LevelDebug: + + } + + return logger +} + +// FromContext retrieves a logger from context. +func FromContext(ctx context.Context) Logger { + v, ok := ctx.Value(LoggerCtx{}).(Logger) + if !ok { + v = NewConsoleLogger(LevelDebug) + } + return v +} + +// CorrelationID retrieves a correlation ID from context. +func CorrelationID(ctx context.Context) uuid.UUID { + v, ok := ctx.Value(CorrelationIDCtx{}).(uuid.UUID) + if !ok { + v = uuid.Nil + } + return v +} + +type consoleLogger struct { + warning loggingFunc + panic loggingFunc + error loggingFunc + info loggingFunc + debug loggingFunc + prefix string +} + +func (ll *consoleLogger) Panic(format string, v ...any) { + ll.panic(ll, format, v...) +} + +func (ll *consoleLogger) Error(format string, v ...any) { + ll.error(ll, format, v...) +} + +func (ll *consoleLogger) Warning(format string, v ...any) { + ll.warning(ll, format, v...) +} + +func (ll *consoleLogger) Info(format string, v ...any) { + ll.info(ll, format, v...) +} + +func (ll *consoleLogger) Debug(format string, v ...any) { + ll.debug(ll, format, v...) +} + +// println 打印 +func (ll *consoleLogger) println(level string, msg string) { + c := color.New() + _, filename, line, _ := runtime.Caller(3) + + _, _ = c.Printf( + "%s\t %s [%s:%d]%s %s\n", + colors[level]("["+level+"]"), + time.Now().Format("2006-01-02 15:04:05"), + filename, + line, + ll.prefix, + msg, + ) +} + +func (ll *consoleLogger) CopyWithPrefix(prefix string) Logger { + return &consoleLogger{ + warning: ll.warning, + panic: ll.panic, + error: ll.error, + info: ll.info, + debug: ll.debug, + prefix: ll.prefix + " " + prefix, + } +} + +func (ll *consoleLogger) SupportColor() bool { + return !color.NoColor +} + +type loggingFunc func(*consoleLogger, string, ...any) + +func noopLoggingFunc(*consoleLogger, string, ...any) {} + +var colors = map[string]func(a ...interface{}) string{ + "Warn": color.New(color.FgYellow).Add(color.Bold).SprintFunc(), + "Panic": color.New(color.BgRed).Add(color.Bold).SprintFunc(), + "Error": color.New(color.FgRed).Add(color.Bold).SprintFunc(), + "Info": color.New(color.FgCyan).Add(color.Bold).SprintFunc(), + "Debug": color.New(color.FgWhite).Add(color.Bold).SprintFunc(), +} + +// Request helper fund to log request. 
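The logger and correlation ID travel through the request context under the LoggerCtx and CorrelationIDCtx keys, with safe fallbacks when either is missing, and the Request helper defined next renders one access-log line per HTTP call. A sketch of how a handler might use both — the route and client IP are made up for illustration:

```go
package example

import (
	"context"
	"time"

	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
	"github.com/gofrs/uuid"
)

// withLogging seeds both context values, middleware-style.
func withLogging(ctx context.Context, l logging.Logger) context.Context {
	ctx = context.WithValue(ctx, logging.LoggerCtx{}, l)
	return context.WithValue(ctx, logging.CorrelationIDCtx{}, uuid.Must(uuid.NewV4()))
}

func handle(ctx context.Context) {
	// FromContext falls back to a debug console logger; CorrelationID to uuid.Nil.
	l := logging.FromContext(ctx)
	cid := logging.CorrelationID(ctx)

	start := time.Now()
	l.Debug("handling request %s", cid)
	// ... do the actual work ...

	// One access-log line: incoming request, status 200, no error string.
	logging.Request(l, true, 200, "GET", "203.0.113.5", "/api/v4/site/ping", "", start)
}
```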
+func Request(l Logger, incoming bool, code int, method, clientIP, path, err string, start time.Time) { + param := gin.LogFormatterParams{ + StatusCode: code, + Method: method, + } + param.StatusCode = code + + var statusColor, methodColor, resetColor string + if l.SupportColor() { + statusColor = param.StatusCodeColor() + methodColor = param.MethodColor() + resetColor = param.ResetColor() + } + + category := "Incoming" + if !incoming { + category = "Outgoing" + } + + l.Info( + "[%s] %s %3d %s| %13v | %15s |%s %-7s %s %#v", + category, + statusColor, param.StatusCode, resetColor, + time.Now().Sub(start), + clientIP, + methodColor, method, resetColor, + path, + ) + if err != "" { + l.Error("%s", err) + } +} diff --git a/pkg/mediameta/exif.go b/pkg/mediameta/exif.go new file mode 100644 index 00000000..27e0d213 --- /dev/null +++ b/pkg/mediameta/exif.go @@ -0,0 +1,924 @@ +package mediameta + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "regexp" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/dsoprea/go-exif/v3" + exifcommon "github.com/dsoprea/go-exif/v3/common" + heicexif "github.com/dsoprea/go-heic-exif-extractor" + jpegstructure "github.com/dsoprea/go-jpeg-image-structure" + pngstructure "github.com/dsoprea/go-png-image-structure" + tiffstructure "github.com/dsoprea/go-tiff-image-structure" + riimage "github.com/dsoprea/go-utility/image" +) + +var ( + exifExts = []string{ + "jpg", + "jpeg", + "png", + "heic", + "heif", + "tiff", + "avif", + // R + "3fr", "ari", "arw", "bay", "braw", "crw", "cr2", "cr3", "cap", "data", "dcs", "dcr", "dng", "drf", "eip", "erf", "fff", "gpr", "iiq", "k25", "kdc", "mdc", "mef", "mos", "mrw", "nef", "nrw", "obm", "orf", "pef", "ptx", "pxn", "r3d", "raf", "raw", "rwl", "rw2", "rwz", "sr2", "srf", "srw", "tif", "x3f", + } + exifIfdMapping *exifcommon.IfdMapping + exifTagIndex = exif.NewTagIndex() + exifDateTimeTags = []string{"DateTimeOriginal", "DateTimeCreated", "CreateDate", "DateTime", "DateTimeDigitized"} + ExifDateTimeMatch = make(map[string]int) + ExifDateTimeRegexp = regexp.MustCompile("((?P\\d{4})|\\D{4})\\D((?P\\d{2})|\\D{2})\\D((?P\\d{2})|\\D{2})\\D((?P\\d{2})|\\D{2})\\D((?P\\d{2})|\\D{2})\\D((?P\\d{2})|\\D{2})(\\.(?P\\d+))?(?P\\D)?(?P\\d{2})?\\D?(?P\\d{2})?") + YearMax = time.Now().Add(OneYear * 3).Year() + UnwantedDescriptions = map[string]bool{ + "Created by Imlib": true, // Apps + "iClarified": true, + "OLYMPUS DIGITAL CAMERA": true, // Olympus + "SAMSUNG": true, // Samsung + "SAMSUNG CAMERA PICTURES": true, + "": true, + "SONY DSC": true, // Sony + "rhdr": true, // Huawei + "hdrpl": true, + "oznorWO": true, + "frontbhdp": true, + "fbt": true, + "rbt": true, + "ptr": true, + "fbthdr": true, + "btr": true, + "mon": true, + "nor": true, + "dav": true, + "mde": true, + "mde_soft": true, + "edf": true, + "btfmdn": true, + "btf": true, + "btfhdr": true, + "frem": true, + "oznor": true, + "rpt": true, + "burst": true, + "sdr_HDRB": true, + "cof": true, + "qrf": true, + "fshbty": true, + "binary comment": true, // Other + "default": true, + "Exif_JPEG_PICTURE": true, + "DVC 10.1 HDMI": true, + "charset=Ascii": true, + } +) + +const ( + OneYear = time.Hour * 24 * 365 + LatMax = 90 + LngMax = 180 + + GpsLat = "latitude" + GpsLng = "longitude" + GpsAttitude = "altitude" + Artist = "artist" + 
Copyright = "copyright" + CameraModel = "camera_model" + CameraMake = "camera_make" + CameraOwnerName = "camera_owner" + BodySerialNumber = "body_serial" + LensMake = "lens_make" + LensModel = "lens_model" + Software = "software" + ExposureTime = "exposure_time" + FNumber = "f" + ApertureValue = "aperture" + FocalLength = "focal_length" + ISOSpeedRatings = "iso" + PixelXDimension = "x" + PixelYDimension = "y" + Orientation = "orientation" + TakenAt = "taken_at" + Flash = "flash" + ImageDescription = "des" + ProjectionType = "projection_type" + ExposureBiasValue = "exposure_bias" +) + +func init() { + exifIfdMapping = exifcommon.NewIfdMapping() + _ = exifcommon.LoadStandardIfds(exifIfdMapping) + names := ExifDateTimeRegexp.SubexpNames() + for i := 0; i < len(names); i++ { + if name := names[i]; name != "" { + ExifDateTimeMatch[name] = i + } + } +} + +type exifExtractor struct { + settings setting.Provider + l logging.Logger +} + +func newExifExtractor(settings setting.Provider, l logging.Logger) *exifExtractor { + return &exifExtractor{ + settings: settings, + l: l, + } +} + +func (e *exifExtractor) Exts() []string { + return exifExts +} + +// Reference: https://github.com/photoprism/photoprism/blob/602097635f1c84d91f2d919f7aedaef7a07fc458/internal/meta/exif.go +func (e *exifExtractor) Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) { + localLimit, remoteLimit := e.settings.MediaMetaExifSizeLimit(ctx) + if err := checkFileSize(localLimit, remoteLimit, source); err != nil { + return nil, err + } + + bruteForce := e.settings.MediaMetaExifBruteForce(ctx) + var ( + err error + exifData []byte + ) + parser := getExifParser(ext) + if parser == nil { + if !bruteForce { + return nil, errors.New("no available exif parser found") + } + + } else { + var res riimage.MediaContext + res, err = parser.Parse(source, int(source.Entity().Size())) + if err != nil { + err = fmt.Errorf("failed to parse exif: %s", err) + } else { + _, exifData, err = res.Exif() + if err != nil { + err = fmt.Errorf("failed to parse exif root: %s", err) + } + } + } + + if !bruteForce && err != nil { + return nil, err + } else if bruteForce && (err != nil || parser == nil) { + e.l.Debug("Failed to parse exif: %s, trying brute force.", err) + exifData, err = exif.SearchAndExtractExifWithReader(source) + if err != nil { + if errors.Is(err, exif.ErrNoExif) { + e.l.Debug("No exif data found") + return nil, nil + } + + return nil, fmt.Errorf("failed to brute force to parse exif: %s", err) + } + } + + entries, _, err := exif.GetFlatExifData(exifData, &exif.ScanOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to parse exif entries: %s", err) + } + + exifMap := make(map[string]string, len(entries)) + for _, tag := range entries { + s := strings.Split(tag.FormattedFirst, "\x00") + if tag.TagName == "" || len(s) == 0 { + } else if s[0] != "" && (exifMap[tag.TagName] == "" || tag.IfdPath != exif.ThumbnailFqIfdPath) { + exifMap[tag.TagName] = s[0] + } + } + + if len(exifMap) == 0 { + return nil, errors.New("no exif data found") + } + + metas := make([]driver.MediaMeta, 0) + takenTimeGps := time.Time{} + + // Extract GPS info + var ifdIndex exif.IfdIndex + _, ifdIndex, err = exif.Collect(exifIfdMapping, exifTagIndex, exifData) + if err != nil { + e.l.Debug("Failed to collect exif data: %s", err) + } else { + var ifd *exif.Ifd + if ifd, err = ifdIndex.RootIfd.ChildWithIfdPath(exifcommon.IfdGpsInfoStandardIfdIdentity); err == nil { + var gi *exif.GpsInfo + if gi, err = 
ifd.GpsInfo(); err != nil { + e.l.Debug("Failed to collect exif gps data: %s", err) + } else { + if !math.IsNaN(gi.Latitude.Decimal()) && !math.IsNaN(gi.Longitude.Decimal()) { + lat, lng := NormalizeGPS(gi.Latitude.Decimal(), gi.Longitude.Decimal()) + metas = append(metas, driver.MediaMeta{ + Key: GpsLat, + Value: fmt.Sprintf("%f", lat), + }, driver.MediaMeta{ + Key: GpsLng, + Value: fmt.Sprintf("%f", lng), + }) + } else if gi.Altitude != 0 || !gi.Timestamp.IsZero() { + e.l.Warning("GPS data is invalid: %s", gi.String()) + } + + if gi.Altitude != 0 { + metas = append(metas, driver.MediaMeta{ + Key: GpsAttitude, + Value: fmt.Sprintf("%d", gi.Altitude), + }) + } + + if !gi.Timestamp.IsZero() { + takenTimeGps = gi.Timestamp + } + } + } + } + + metas = append(metas, ExtractExifMap(exifMap, takenTimeGps)...) + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MetaTypeExif + } + + return metas, nil +} + +func ExtractExifMap(exifMap map[string]string, gpsTime time.Time) []driver.MediaMeta { + metas := make([]driver.MediaMeta, 0) + if value, ok := exifMap["Artist"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: Artist, + Value: SanitizeMeta(value), + }) + } + + if value, ok := exifMap["Copyright"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: Copyright, + Value: SanitizeString(value), + }) + } + + cameraMode := "" + if value, ok := exifMap["CameraModel"]; ok && !IsUInt(value) { + cameraMode = SanitizeString(value) + } else if value, ok = exifMap["Model"]; ok && !IsUInt(value) { + cameraMode = SanitizeString(value) + } else if value, ok = exifMap["UniqueCameraModel"]; ok && !IsUInt(value) { + cameraMode = SanitizeString(value) + } + if cameraMode != "" { + metas = append(metas, driver.MediaMeta{ + Key: CameraModel, + Value: cameraMode, + }) + } + + cameraMake := "" + if value, ok := exifMap["CameraMake"]; ok && !IsUInt(value) { + cameraMake = SanitizeString(value) + } else if value, ok = exifMap["Make"]; ok && !IsUInt(value) { + cameraMake = SanitizeString(value) + } + if cameraMake != "" { + metas = append(metas, driver.MediaMeta{ + Key: CameraMake, + Value: cameraMake, + }) + } + + if value, ok := exifMap["CameraOwnerName"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: CameraOwnerName, + Value: SanitizeString(value), + }) + } + + if value, ok := exifMap["BodySerialNumber"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: BodySerialNumber, + Value: SanitizeString(value), + }) + } + + if value, ok := exifMap["LensMake"]; ok && !IsUInt(value) { + metas = append(metas, driver.MediaMeta{ + Key: LensMake, + Value: SanitizeString(value), + }) + } + + lens := "" + if value, ok := exifMap["LensModel"]; ok && !IsUInt(value) { + lens = SanitizeString(value) + } else if value, ok = exifMap["Lens"]; ok && !IsUInt(value) { + lens = SanitizeString(value) + } + if lens != "" { + metas = append(metas, driver.MediaMeta{ + Key: LensModel, + Value: lens, + }) + } + + if value, ok := exifMap["Software"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: Software, + Value: SanitizeString(value), + }) + } + + if value, ok := exifMap["ExposureTime"]; ok { + value = strings.TrimSuffix(value, " sec.") + if n := strings.Split(value, "/"); len(n) == 2 { + if n[0] != "1" && len(n[0]) < len(n[1]) { + n0, _ := strconv.ParseUint(n[0], 10, 64) + if n1, err := strconv.ParseUint(n[1], 10, 64); err == nil && n0 > 0 && n1 > 0 { + value = fmt.Sprintf("1/%d", n1/n0) + } + } + } + + metas = append(metas, driver.MediaMeta{ + Key: ExposureTime, + Value: value, + }) + } + + if value, ok := 
exifMap["ExposureBiasValue"]; ok { + if n := strings.Split(value, "/"); len(n) == 2 { + n0, _ := strconv.ParseInt(n[0], 10, 64) + if n1, err := strconv.ParseInt(n[1], 10, 64); err == nil { + v := "0" + v = fmt.Sprintf("%f", float64(n0)/float64(n1)) + metas = append(metas, driver.MediaMeta{ + Key: ExposureBiasValue, + Value: v, + }) + } + } + } + + if value, ok := exifMap["FNumber"]; ok { + values := strings.Split(value, "/") + + if len(values) == 2 && values[1] != "0" && values[1] != "" { + number, _ := strconv.ParseFloat(values[0], 64) + denom, _ := strconv.ParseFloat(values[1], 64) + + metas = append(metas, driver.MediaMeta{ + Key: FNumber, + Value: fmt.Sprintf("%f", float32(math.Round((number/denom)*1000)/1000)), + }) + } + } + + if value, ok := exifMap["ApertureValue"]; ok { + values := strings.Split(value, "/") + + if len(values) == 2 && values[1] != "0" && values[1] != "" { + number, _ := strconv.ParseFloat(values[0], 64) + denom, _ := strconv.ParseFloat(values[1], 64) + + metas = append(metas, driver.MediaMeta{ + Key: ApertureValue, + Value: fmt.Sprintf("%f", float32(math.Round((number/denom)*1000)/1000)), + }) + } + } + + focalLength := "" + if value, ok := exifMap["FocalLengthIn35mmFilm"]; ok { + focalLength = value + } else if v, ok := exifMap["FocalLength"]; ok { + values := strings.Split(v, "/") + + if len(values) == 2 && values[1] != "0" && values[1] != "" { + number, _ := strconv.ParseFloat(values[0], 64) + denom, _ := strconv.ParseFloat(values[1], 64) + + focalLength = strconv.Itoa(int(math.Round((number/denom)*1000) / 1000)) + } + } + if focalLength != "" { + metas = append(metas, driver.MediaMeta{ + Key: FocalLength, + Value: focalLength, + }) + } + + if value, ok := exifMap["ISOSpeedRatings"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: ISOSpeedRatings, + Value: value, + }) + } + + width := "" + if value, ok := exifMap["PixelXDimension"]; ok { + width = value + } else if value, ok := exifMap["ImageWidth"]; ok { + width = value + } + if width != "" { + metas = append(metas, driver.MediaMeta{ + Key: PixelXDimension, + Value: width, + }) + } + + height := "" + if value, ok := exifMap["PixelYDimension"]; ok { + height = value + } else if value, ok := exifMap["ImageLength"]; ok { + height = value + } + if height != "" { + metas = append(metas, driver.MediaMeta{ + Key: PixelYDimension, + Value: height, + }) + } + + orientation := "1" + if value, ok := exifMap["Orientation"]; ok { + orientation = value + } + metas = append(metas, driver.MediaMeta{ + Key: Orientation, + Value: orientation, + }) + + takeTime := time.Time{} + for _, name := range exifDateTimeTags { + if dateTime := DateTime(exifMap[name], ""); !dateTime.IsZero() { + takeTime = dateTime + break + } + } + if takeTime.IsZero() { + takeTime = gpsTime.UTC() + } + + if !takeTime.IsZero() { + metas = append(metas, driver.MediaMeta{ + Key: TakenAt, + Value: takeTime.Format(time.RFC3339), + }) + } + + if value, ok := exifMap["Flash"]; ok { + flash := "0" + if i, err := strconv.Atoi(value); err == nil && i&1 == 1 { + flash = "1" + } + metas = append(metas, driver.MediaMeta{ + Key: Flash, + Value: flash, + }) + } + + if value, ok := exifMap["ImageDescription"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: ImageDescription, + Value: SanitizeDescription(value), + }) + } + + if value, ok := exifMap["ProjectionType"]; ok { + metas = append(metas, driver.MediaMeta{ + Key: ProjectionType, + Value: SanitizeString(value), + }) + } + + return metas +} + +type ( + exifParser interface { + Parse(rs io.ReadSeeker, size 
int) (ec riimage.MediaContext, err error) + } +) + +func getExifParser(ext string) exifParser { + switch ext { + case "jpg", "jpeg": + return jpegstructure.NewJpegMediaParser() + case "png": + return pngstructure.NewPngMediaParser() + case "tiff": + return tiffstructure.NewTiffMediaParser() + case "heic", "heif", "avif": + return heicexif.NewHeicExifMediaParser() + default: + return nil + } +} + +// NormalizeGPS normalizes the longitude and latitude of the GPS position to a generally valid range. +func NormalizeGPS(lat, lng float64) (float32, float32) { + if lat < LatMax || lat > LatMax || lng < LngMax || lng > LngMax { + // Clip the latitude. Normalise the longitude. + lat, lng = clipLat(lat), normalizeLng(lng) + } + + return float32(lat), float32(lng) +} + +func clipLat(lat float64) float64 { + if lat > LatMax*2 { + return math.Mod(lat, LatMax) + } else if lat > LatMax { + return lat - LatMax + } + + if lat < -LatMax*2 { + return math.Mod(lat, LatMax) + } else if lat < -LatMax { + return lat + LatMax + } + + return lat +} + +func normalizeLng(value float64) float64 { + return normalizeCoord(value, LngMax) +} + +func normalizeCoord(value, max float64) float64 { + for value < -max { + value += 2 * max + } + for value >= max { + value -= 2 * max + } + return value +} + +// SanitizeString removes unwanted character from an exif value string. +func SanitizeString(s string) string { + if s == "" { + return "" + } + + if strings.HasPrefix(s, "string with binary data") { + return "" + } else if strings.HasPrefix(s, "(Binary data") { + return "" + } + + return SanitizeUnicode(strings.Replace(s, "\"", "", -1)) +} + +// SanitizeUnicode returns the string as valid Unicode with whitespace trimmed. +func SanitizeUnicode(s string) string { + if s == "" { + return "" + } + + return unicode(strings.TrimSpace(s)) +} + +// SanitizeMeta normalizes metadata fields that may contain JSON arrays like keywords and subject. +func SanitizeMeta(s string) string { + if s == "" { + return "" + } + + if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") { + var words []string + + if err := json.Unmarshal([]byte(s), &words); err != nil { + return s + } + + s = strings.Join(words, ", ") + } else { + s = SanitizeString(s) + } + + return s +} + +func unicode(s string) string { + if s == "" { + return "" + } + + var b strings.Builder + + for _, c := range s { + if c == '\uFFFD' { + continue + } + b.WriteRune(c) + } + + return b.String() +} + +func IsUInt(s string) bool { + if s == "" { + return false + } + + for _, r := range s { + if r < 48 || r > 57 { + return false + } + } + + return true +} + +// DateTime parses a time string and returns a valid time.Time if possible. +func DateTime(s, timeZone string) (t time.Time) { + defer func() { + if r := recover(); r != nil { + // Panic? Return unknown time. + t = time.Time{} + } + }() + + // Ignore defaults. + if DateTimeDefault(s) { + return time.Time{} + } + + s = strings.TrimLeft(s, " ") + + // Timestamp too short? + if len(s) < 4 { + return time.Time{} + } else if len(s) > 50 { + // Clip to max length. + s = s[:50] + } + + // Pad short timestamp with whitespace at the end. + s = fmt.Sprintf("%-19s", s) + + v := ExifDateTimeMatch + m := ExifDateTimeRegexp.FindStringSubmatch(s) + + // Pattern doesn't match? Return unknown time. + if len(m) == 0 { + return time.Time{} + } + + // Default to UTC. + tz := time.UTC + + // Local time zone currently not supported (undefined). + if timeZone == time.Local.String() { + timeZone = "" + } + + // Set time zone. 
+ loc := TimeZone(timeZone) + + // Location found? + if loc != nil && timeZone != "" && tz != time.Local { + tz = loc + timeZone = tz.String() + } else { + timeZone = "" + } + + // Does the timestamp contain a time zone offset? + z := m[v["z"]] // Supported values, if not empty: Z, +, - + zh := IntVal(m[v["zh"]], 0, 23, 0) // Hours. + zm := IntVal(m[v["zm"]], 0, 59, 0) // Minutes. + + // Valid time zone offset found? + if offset := (zh*60 + zm) * 60; offset > 0 && offset <= 86400 { + // Offset timezone name example: UTC+03:30 + if z == "+" { + // Positive offset relative to UTC. + tz = time.FixedZone(fmt.Sprintf("UTC+%02d:%02d", zh, zm), offset) + } else if z == "-" { + // Negative offset relative to UTC. + tz = time.FixedZone(fmt.Sprintf("UTC-%02d:%02d", zh, zm), -1*offset) + } + } + + var nsec int + + if subsec := m[v["subsec"]]; subsec != "" { + nsec = Int(subsec + strings.Repeat("0", 9-len(subsec))) + } else { + nsec = 0 + } + + // Create rounded timestamp from parsed input values. + // Year 0 is treated separately as it has a special meaning in exiftool. Golang + // does not seem to accept value 0 for the year, but considers a date to be + // "zero" when year is 1. + year := IntVal(m[v["year"]], 0, YearMax, time.Now().Year()) + if year == 0 { + year = 1 + } + t = time.Date( + year, + time.Month(IntVal(m[v["month"]], 1, 12, 1)), + IntVal(m[v["day"]], 1, 31, 1), + IntVal(m[v["h"]], 0, 23, 0), + IntVal(m[v["m"]], 0, 59, 0), + IntVal(m[v["s"]], 0, 59, 0), + nsec, + tz) + + if timeZone != "" && loc != nil && loc != tz { + return t.In(loc) + } + + return t +} + +// Int converts a string to a signed integer or 0 if invalid. +func Int(s string) int { + if s == "" { + return 0 + } + + result, err := strconv.ParseInt(strings.TrimSpace(s), 10, 32) + + if err != nil { + return 0 + } + + return int(result) +} + +// IntVal converts a string to a validated integer or a default if invalid. +func IntVal(s string, min, max, def int) (i int) { + if s == "" { + return def + } else if s[0] == ' ' { + s = strings.TrimSpace(s) + } + + result, err := strconv.ParseInt(s, 10, 32) + + if err != nil { + return def + } + + i = int(result) + + if i < min { + return def + } else if max != 0 && i > max { + return def + } + + return i +} + +// DateTimeDefault tests if the datetime string is not empty and not a default value. +func DateTimeDefault(s string) bool { + switch s { + case "1970-01-01", "1970-01-01 00:00:00", "1970:01:01 00:00:00": + // Unix epoch. + return true + case "1980-01-01", "1980-01-01 00:00:00", "1980:01:01 00:00:00": + // Windows default. + return true + case "2002-12-08 12:00:00", "2002:12:08 12:00:00": + // Android Bug: https://issuetracker.google.com/issues/36967504 + return true + default: + return EmptyDateTime(s) + } +} + +// EmptyDateTime tests if the string is empty or matches an unknown time pattern. +func EmptyDateTime(s string) bool { + switch s { + case "", "-", ":", "z", "Z", "nil", "null", "none", "nan", "NaN": + return true + case "0", "00", "0000", "0000:00:00", "00:00:00", "0000-00-00", "00-00-00": + return true + case " : : : : ", " - - - - ", " - - : : ": + // Exif default. + return true + case "0000:00:00 00:00:00", "0000-00-00 00-00-00", "0000-00-00 00:00:00": + return true + case "0001:01:01 00:00:00", "0001-01-01 00-00-00", "0001-01-01 00:00:00": + // Go default. + return true + case "0001:01:01 00:00:00 +0000 UTC", "0001-01-01 00-00-00 +0000 UTC", "0001-01-01 00:00:00 +0000 UTC": + // Go default with time zone. 
+ return true + default: + return false + } +} + +// TimeZone returns a time zone for the given UTC offset string. +func TimeZone(offset string) *time.Location { + if offset == "" { + // Local time. + } else if offset == "UTC" || offset == "Z" { + return time.UTC + } else if seconds, err := TimeOffset(offset); err == nil { + if h := seconds / 3600; h > 0 || h < 0 { + return time.FixedZone(fmt.Sprintf("UTC%+d", h), seconds) + } + } else if zone, zoneErr := time.LoadLocation(offset); zoneErr == nil { + return zone + } + + return time.FixedZone("", 0) +} + +// TimeOffset returns the UTC time offset in seconds or an error if it is invalid. +func TimeOffset(utcOffset string) (seconds int, err error) { + switch utcOffset { + case "-12", "-12:00", "UTC-12", "UTC-12:00": + seconds = -12 * 3600 + case "-11", "-11:00", "UTC-11", "UTC-11:00": + seconds = -11 * 3600 + case "-10", "-10:00", "UTC-10", "UTC-10:00": + seconds = -10 * 3600 + case "-9", "-09", "-09:00", "UTC-9", "UTC-09:00": + seconds = -9 * 3600 + case "-8", "-08", "-08:00", "UTC-8", "UTC-08:00": + seconds = -8 * 3600 + case "-7", "-07", "-07:00", "UTC-7", "UTC-07:00": + seconds = -7 * 3600 + case "-6", "-06", "-06:00", "UTC-6", "UTC-06:00": + seconds = -6 * 3600 + case "-5", "-05", "-05:00", "UTC-5", "UTC-05:00": + seconds = -5 * 3600 + case "-4", "-04", "-04:00", "UTC-4", "UTC-04:00": + seconds = -4 * 3600 + case "-3", "-03", "-03:00", "UTC-3", "UTC-03:00": + seconds = -3 * 3600 + case "-2", "-02", "-02:00", "UTC-2", "UTC-02:00": + seconds = -2 * 3600 + case "-1", "-01", "-01:00", "UTC-1", "UTC-01:00": + seconds = -1 * 3600 + case "01:00", "+1", "+01", "+01:00", "UTC+1", "UTC+01:00": + seconds = 1 * 3600 + case "02:00", "+2", "+02", "+02:00", "UTC+2", "UTC+02:00": + seconds = 2 * 3600 + case "03:00", "+3", "+03", "+03:00", "UTC+3", "UTC+03:00": + seconds = 3 * 3600 + case "04:00", "+4", "+04", "+04:00", "UTC+4", "UTC+04:00": + seconds = 4 * 3600 + case "05:00", "+5", "+05", "+05:00", "UTC+5", "UTC+05:00": + seconds = 5 * 3600 + case "06:00", "+6", "+06", "+06:00", "UTC+6", "UTC+06:00": + seconds = 6 * 3600 + case "07:00", "+7", "+07", "+07:00", "UTC+7", "UTC+07:00": + seconds = 7 * 3600 + case "08:00", "+8", "+08", "+08:00", "UTC+8", "UTC+08:00": + seconds = 8 * 3600 + case "09:00", "+9", "+09", "+09:00", "UTC+9", "UTC+09:00": + seconds = 9 * 3600 + case "10:00", "+10", "+10:00", "UTC+10", "UTC+10:00": + seconds = 10 * 3600 + case "11:00", "+11", "+11:00", "UTC+11", "UTC+11:00": + seconds = 11 * 3600 + case "12:00", "+12", "+12:00", "UTC+12", "UTC+12:00": + seconds = 12 * 3600 + case "Z", "UTC", "UTC+0", "UTC-0", "UTC+00:00", "UTC-00:00": + seconds = 0 + default: + return 0, fmt.Errorf("invalid UTC offset") + } + + return seconds, nil +} + +func SanitizeDescription(s string) string { + s = SanitizeString(s) + + switch { + case s == "": + return "" + case UnwantedDescriptions[s]: + return "" + case strings.HasPrefix(s, "DCIM\\") && !strings.Contains(s, " "): + return "" + default: + return s + } +} diff --git a/pkg/mediameta/extractor.go b/pkg/mediameta/extractor.go new file mode 100644 index 00000000..e871189e --- /dev/null +++ b/pkg/mediameta/extractor.go @@ -0,0 +1,106 @@ +package mediameta + +import ( + "context" + "encoding/gob" + "errors" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/samber/lo" + "io" +) + +type ( + Extractor 
interface { + // Exts returns the supported file extensions. + Exts() []string + // Extract extracts the media meta from the given source. + Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) + } +) + +var ( + ErrFileTooLarge = errors.New("file too large") +) + +func init() { + gob.Register([]driver.MediaMeta{}) +} + +func NewExtractorManager(ctx context.Context, settings setting.Provider, l logging.Logger) Extractor { + e := &extractorManager{ + settings: settings, + extMap: make(map[string][]Extractor), + } + + extractors := []Extractor{} + + if e.settings.MediaMetaExifEnabled(ctx) { + exifE := newExifExtractor(settings, l) + extractors = append(extractors, exifE) + } + + if e.settings.MediaMetaMusicEnabled(ctx) { + musicE := newMusicExtractor(settings, l) + extractors = append(extractors, musicE) + } + + if e.settings.MediaMetaFFProbeEnabled(ctx) { + ffprobeE := newFFProbeExtractor(settings, l) + extractors = append(extractors, ffprobeE) + } + + for _, extractor := range extractors { + for _, ext := range extractor.Exts() { + if e.extMap[ext] == nil { + e.extMap[ext] = []Extractor{} + } + e.extMap[ext] = append(e.extMap[ext], extractor) + } + } + + return e +} + +type extractorManager struct { + settings setting.Provider + extMap map[string][]Extractor +} + +func (e *extractorManager) Exts() []string { + return lo.Keys(e.extMap) +} + +func (e *extractorManager) Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) { + if extractor, ok := e.extMap[ext]; ok { + res := []driver.MediaMeta{} + for _, e := range extractor { + _, _ = source.Seek(0, io.SeekStart) + data, err := e.Extract(ctx, ext, source) + if err != nil { + return nil, err + } + + res = append(res, data...) + } + + return res, nil + } else { + return nil, nil + } +} + +// checkFileSize checks if the file size exceeds the limit. 
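The manager above dispatches purely on file extension: every extractor whose Exts() list contains the extension receives the EntitySource rewound to offset 0 and contributes to the merged metadata slice. A hypothetical implementation of that contract, shown only to illustrate the interface shape — it is not a real Cloudreve extractor, and plugging one in would also require extending NewExtractorManager:

```go
package example

import (
	"context"
	"strconv"

	"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
	"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource"
)

// sidecarExtractor is a made-up extractor used only to show the interface shape.
type sidecarExtractor struct{}

// Exts lists the file extensions this extractor claims.
func (s *sidecarExtractor) Exts() []string {
	return []string{"xmp"}
}

// Extract receives the source already seeked back to offset 0 by the manager
// and returns flat key/value metadata entries.
func (s *sidecarExtractor) Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) {
	return []driver.MediaMeta{
		{Key: "sidecar_size", Value: strconv.FormatInt(source.Entity().Size(), 10), Type: driver.MetaTypeExif},
	}, nil
}
```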
+func checkFileSize(localLimit, remoteLimit int64, source entitysource.EntitySource) error { + if source.IsLocal() && localLimit > 0 && source.Entity().Size() > localLimit { + return ErrFileTooLarge + } + + if !source.IsLocal() && remoteLimit > 0 && source.Entity().Size() > remoteLimit { + return ErrFileTooLarge + } + + return nil +} diff --git a/pkg/mediameta/ffprobe.go b/pkg/mediameta/ffprobe.go new file mode 100644 index 00000000..3cc1117f --- /dev/null +++ b/pkg/mediameta/ffprobe.go @@ -0,0 +1,245 @@ +package mediameta + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +var ( + ffprobeExts = []string{ + "mp3", "m4a", "ogg", "flac", "3g2", "3gp", "asf", "asx", "avi", "divx", "flv", "m2ts", "m2v", "m4v", "mkv", "mov", "mp4", + "mpeg", "mpg", "mts", "mxf", "ogv", "rm", "swf", "webm", "wmv", + } +) + +type ( + FFProbeMeta struct { + Format *Format `json:"format"` + Streams []Stream `json:"streams"` + Chapters []Chapter `json:"chapters"` + } + + Stream struct { + Index int `json:"index"` + CodecName string `json:"codec_name"` + CodecLongName string `json:"codec_long_name"` + CodecType string `json:"codec_type"` + Width int `json:"width"` + Height int `json:"height"` + Duration string `json:"duration"` + Bitrate string `json:"bit_rate"` + } + Chapter struct { + Id int `json:"id"` + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Tags map[string]string `json:"tags"` + } + Format struct { + FormatName string `json:"format_name"` + FormatLongName string `json:"format_long_name"` + Duration string `json:"duration"` + Bitrate string `json:"bit_rate"` + Tags map[string]string `json:"tags"` + } +) + +const ( + UrlExpire = time.Duration(60) * time.Hour + StreamMediaFormat = "format" + StreamMediaFormatLong = "formatLong" + StreamMediaDuration = "duration" + StreamMediaBitrate = "bitrate" + StreamMediaStreamPrefix = "stream_" + StreamMediaChapterPrefix = "chapter_" + StreamMediaCodec = "codec" + StreamMediaCodecLongName = "codec_long_name" + StreamMediaWidth = "width" + StreamMediaHeight = "height" + StreamMediaStartTime = "start_time" + StreamMediaEndTime = "end_time" + StreamMediaChapterName = "name" + StreamMetaTitle = "title" + StreamMetaDescription = "description" +) + +func newFFProbeExtractor(settings setting.Provider, l logging.Logger) *ffprobeExtractor { + return &ffprobeExtractor{ + l: l, + settings: settings, + } +} + +type ffprobeExtractor struct { + settings setting.Provider + l logging.Logger +} + +func (f *ffprobeExtractor) Exts() []string { + return ffprobeExts +} + +func (f *ffprobeExtractor) Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) { + localLimit, remoteLimit := f.settings.MediaMetaFFProbeSizeLimit(ctx) + if err := checkFileSize(localLimit, remoteLimit, source); err != nil { + return nil, err + } + + var input string + if source.IsLocal() { + input = source.LocalPath(ctx) + } else { + expire := time.Now().Add(UrlExpire) + srcUrl, err := source.Url(driver.WithForcePublicEndpoint(ctx, false), entitysource.WithNoInternalProxy(), entitysource.WithExpire(&expire)) + if err != nil { + return nil, fmt.Errorf("failed to get entity url: %w", err) + } + input = srcUrl.Url + } + + cmd := exec.CommandContext(ctx, + 
f.settings.MediaMetaFFProbePath(ctx), + "-v", "quiet", + "-print_format", "json", + "-show_format", + "-show_streams", + "-show_chapters", + input, + ) + + res, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to invoke ffprobe: %w", err) + } + + f.l.Debug("ffprobe output: %s", res) + var meta FFProbeMeta + if err := json.Unmarshal(res, &meta); err != nil { + return nil, fmt.Errorf("failed to parse ffprobe output: %w", err) + } + + return ProbeMetaTransform(&meta), nil +} + +func ProbeMetaTransform(meta *FFProbeMeta) []driver.MediaMeta { + if meta.Format == nil { + return nil + } + + res := []driver.MediaMeta{} + if meta.Format.FormatName != "" { + res = append(res, driver.MediaMeta{ + Key: StreamMediaFormat, + Value: meta.Format.FormatName, + }) + } + if meta.Format.FormatLongName != "" { + res = append(res, driver.MediaMeta{ + Key: StreamMediaFormatLong, + Value: meta.Format.FormatLongName, + }) + } + if meta.Format.Duration != "" { + res = append(res, driver.MediaMeta{ + Key: StreamMediaDuration, + Value: meta.Format.Duration, + }) + } + if meta.Format.Bitrate != "" { + res = append(res, driver.MediaMeta{ + Key: StreamMediaBitrate, + Value: meta.Format.Bitrate, + }) + } + + for _, stream := range meta.Streams { + keyPrefix := fmt.Sprintf("%s%d_%s_", StreamMediaStreamPrefix, stream.Index, stream.CodecType) + if stream.CodecName != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaCodec, + Value: stream.CodecName, + }) + } + if stream.CodecLongName != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaCodecLongName, + Value: stream.CodecLongName, + }) + } + if stream.Width > 0 { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaWidth, + Value: strconv.Itoa(stream.Width), + }) + } + if stream.Height > 0 { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaHeight, + Value: strconv.Itoa(stream.Height), + }) + } + if stream.Duration != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaDuration, + Value: stream.Duration, + }) + } + if stream.Bitrate != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaBitrate, + Value: stream.Bitrate, + }) + } + } + + for _, chapter := range meta.Chapters { + keyPrefix := fmt.Sprintf("%s%d_", StreamMediaChapterPrefix, chapter.Id) + if chapter.StartTime != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaStartTime, + Value: chapter.StartTime, + }) + } + if chapter.EndTime != "" { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaEndTime, + Value: chapter.EndTime, + }) + } + if title, ok := chapter.Tags["title"]; ok { + res = append(res, driver.MediaMeta{ + Key: keyPrefix + StreamMediaChapterName, + Value: title, + }) + } + } + + if title, ok := meta.Format.Tags["title"]; ok { + res = append(res, driver.MediaMeta{ + Key: StreamMetaTitle, + Value: title, + }) + } + + if description, ok := meta.Format.Tags["description"]; ok { + res = append(res, driver.MediaMeta{ + Key: StreamMetaDescription, + Value: description[0:min(len(description), 255)], + }) + } + + for i := 0; i < len(res); i++ { + res[i].Type = driver.MetaTypeStreamMedia + } + + return res +} diff --git a/pkg/mediameta/music.go b/pkg/mediameta/music.go new file mode 100644 index 00000000..ac2ec78c --- /dev/null +++ b/pkg/mediameta/music.go @@ -0,0 +1,145 @@ +package mediameta + +import ( + "context" + "errors" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + 
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/dhowden/tag" +) + +var ( + audioExts = []string{ + "mp3", "m4a", "ogg", "flac", + } +) + +const ( + MusicFormat = "format" + MusicFileType = "file_type" + MusicTitle = "title" + MusicAlbum = "album" + MusicArtist = "artist" + MusicAlbumArtists = "album_artists" + MusicComposer = "composer" + MusicGenre = "genre" + MusicYear = "year" + MusicTrack = "track" + MusicDisc = "disc" +) + +func newMusicExtractor(settings setting.Provider, l logging.Logger) *musicExtractor { + return &musicExtractor{ + l: l, + settings: settings, + } +} + +type musicExtractor struct { + l logging.Logger + settings setting.Provider +} + +func (a *musicExtractor) Exts() []string { + return audioExts +} + +func (a *musicExtractor) Extract(ctx context.Context, ext string, source entitysource.EntitySource) ([]driver.MediaMeta, error) { + localLimit, remoteLimit := a.settings.MediaMetaMusicSizeLimit(ctx) + if err := checkFileSize(localLimit, remoteLimit, source); err != nil { + return nil, err + } + + m, err := tag.ReadFrom(source) + if err != nil { + if errors.Is(err, tag.ErrNoTagsFound) { + a.l.Debug("No tags found in file.") + return nil, nil + } + return nil, fmt.Errorf("failed to read tags from file: %w", err) + } + + metas := []driver.MediaMeta{ + { + Key: MusicFormat, + Value: string(m.Format()), + }, + { + Key: MusicFileType, + Value: string(m.FileType()), + }, + } + + if title := m.Title(); title != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicTitle, + Value: title, + }) + } + + if album := m.Album(); album != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicAlbum, + Value: album, + }) + } + + if artist := m.Artist(); artist != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicArtist, + Value: artist, + }) + } + + if albumArtists := m.AlbumArtist(); albumArtists != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicAlbumArtists, + Value: albumArtists, + }) + } + + if composer := m.Composer(); composer != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicComposer, + Value: composer, + }) + } + + if genre := m.Genre(); genre != "" { + metas = append(metas, driver.MediaMeta{ + Key: MusicGenre, + Value: genre, + }) + } + + if year := m.Year(); year != 0 { + metas = append(metas, driver.MediaMeta{ + Key: MusicYear, + Value: fmt.Sprintf("%d", year), + }) + } + + if track, total := m.Track(); track != 0 { + metas = append(metas, driver.MediaMeta{ + Key: MusicTrack, + Value: fmt.Sprintf("%d/%d", track, total), + }) + } + + if disc, total := m.Disc(); disc != 0 { + metas = append(metas, driver.MediaMeta{ + Key: MusicDisc, + Value: fmt.Sprintf("%d/%d", disc, total), + }) + } + + for i := 0; i < len(metas); i++ { + metas[i].Type = driver.MediaTypeMusic + } + + return metas, nil +} diff --git a/pkg/mocks/cachemock/mock.go b/pkg/mocks/cachemock/mock.go deleted file mode 100644 index 921b1cd1..00000000 --- a/pkg/mocks/cachemock/mock.go +++ /dev/null @@ -1,37 +0,0 @@ -package cachemock - -import "github.com/stretchr/testify/mock" - -type CacheClientMock struct { - mock.Mock -} - -func (c CacheClientMock) Set(key string, value interface{}, ttl int) error { - return c.Called(key, value, ttl).Error(0) -} - -func (c CacheClientMock) Get(key string) (interface{}, bool) { - args := c.Called(key) - return args.Get(0), args.Bool(1) -} - -func (c CacheClientMock) Gets(keys []string, prefix string) 
(map[string]interface{}, []string) { - args := c.Called(keys, prefix) - return args.Get(0).(map[string]interface{}), args.Get(1).([]string) -} - -func (c CacheClientMock) Sets(values map[string]interface{}, prefix string) error { - return c.Called(values).Error(0) -} - -func (c CacheClientMock) Delete(keys []string, prefix string) error { - return c.Called(keys, prefix).Error(0) -} - -func (c CacheClientMock) Persist(path string) error { - return c.Called(path).Error(0) -} - -func (c CacheClientMock) Restore(path string) error { - return c.Called(path).Error(0) -} diff --git a/pkg/mocks/controllermock/c.go b/pkg/mocks/controllermock/c.go deleted file mode 100644 index 6a77793a..00000000 --- a/pkg/mocks/controllermock/c.go +++ /dev/null @@ -1,43 +0,0 @@ -package controllermock - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/mock" -) - -type SlaveControllerMock struct { - mock.Mock -} - -func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) { - args := s.Called(pingReq) - return args.Get(0).(serializer.NodePingResp), args.Error(1) -} - -func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) { - args := s.Called(s2) - return args.Get(0).(common.Aria2), args.Error(1) -} - -func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error { - args := s.Called(s3, s2, message) - return args.Error(0) -} - -func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error { - args := s.Called(s3, i, s2, f) - return args.Error(0) -} - -func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) { - args := s.Called(s2) - return args.Get(0).(*cluster.MasterInfo), args.Error(1) -} - -func (s SlaveControllerMock) GetPolicyOauthToken(s2 string, u uint) (string, error) { - args := s.Called(s2, u) - return args.String(0), args.Error(1) -} diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go deleted file mode 100644 index 01c450b8..00000000 --- a/pkg/mocks/mocks.go +++ /dev/null @@ -1,151 +0,0 @@ -package mocks - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - testMock "github.com/stretchr/testify/mock" -) - -type NodePoolMock struct { - testMock.Mock -} - -func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) { - args := n.Called(feature, lb) - return args.Error(0), args.Get(1).(cluster.Node) -} - -func (n NodePoolMock) GetNodeByID(id uint) cluster.Node { - args := n.Called(id) - if res, ok := args.Get(0).(cluster.Node); ok { - return res - } - - return nil -} - -func (n NodePoolMock) Add(node *model.Node) { - n.Called(node) -} - -func (n NodePoolMock) Delete(id uint) { - n.Called(id) -} - -type NodeMock struct { - testMock.Mock -} - -func (n NodeMock) Init(node *model.Node) { - n.Called(node) -} - -func (n NodeMock) IsFeatureEnabled(feature string) bool { - args := n.Called(feature) - return args.Bool(0) -} - -func (n NodeMock) 
SubscribeStatusChange(callback func(isActive bool, id uint)) { - n.Called(callback) -} - -func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { - args := n.Called(req) - return args.Get(0).(*serializer.NodePingResp), args.Error(1) -} - -func (n NodeMock) IsActive() bool { - args := n.Called() - return args.Bool(0) -} - -func (n NodeMock) GetAria2Instance() common.Aria2 { - args := n.Called() - return args.Get(0).(common.Aria2) -} - -func (n NodeMock) ID() uint { - args := n.Called() - return args.Get(0).(uint) -} - -func (n NodeMock) Kill() { - n.Called() -} - -func (n NodeMock) IsMater() bool { - args := n.Called() - return args.Bool(0) -} - -func (n NodeMock) MasterAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n NodeMock) SlaveAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n NodeMock) DBModel() *model.Node { - args := n.Called() - return args.Get(0).(*model.Node) -} - -type Aria2Mock struct { - testMock.Mock -} - -func (a Aria2Mock) Init() error { - args := a.Called() - return args.Error(0) -} - -func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { - args := a.Called(task, options) - return args.String(0), args.Error(1) -} - -func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) { - args := a.Called(task) - return args.Get(0).(rpc.StatusInfo), args.Error(1) -} - -func (a Aria2Mock) Cancel(task *model.Download) error { - args := a.Called(task) - return args.Error(0) -} - -func (a Aria2Mock) Select(task *model.Download, files []int) error { - args := a.Called(task, files) - return args.Error(0) -} - -func (a Aria2Mock) GetConfig() model.Aria2Option { - args := a.Called() - return args.Get(0).(model.Aria2Option) -} - -func (a Aria2Mock) DeleteTempFile(download *model.Download) error { - args := a.Called(download) - return args.Error(0) -} - -type TaskPoolMock struct { - testMock.Mock -} - -func (t TaskPoolMock) Add(num int) { - t.Called(num) -} - -func (t TaskPoolMock) Submit(job task.Job) { - t.Called(job) -} diff --git a/pkg/mocks/remoteclientmock/mock.go b/pkg/mocks/remoteclientmock/mock.go deleted file mode 100644 index 303b6737..00000000 --- a/pkg/mocks/remoteclientmock/mock.go +++ /dev/null @@ -1,32 +0,0 @@ -package remoteclientmock - -import ( - "context" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/mock" -) - -type RemoteClientMock struct { - mock.Mock -} - -func (r *RemoteClientMock) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error { - return r.Called(ctx, session, ttl, overwrite).Error(0) -} - -func (r *RemoteClientMock) GetUploadURL(ttl int64, sessionID string) (string, string, error) { - args := r.Called(ttl, sessionID) - - return args.String(0), args.String(1), args.Error(2) -} - -func (r *RemoteClientMock) Upload(ctx context.Context, file fsctx.FileHeader) error { - args := r.Called(ctx, file) - return args.Error(0) -} - -func (r *RemoteClientMock) DeleteUploadSession(ctx context.Context, sessionID string) error { - args := r.Called(ctx, sessionID) - return args.Error(0) -} diff --git a/pkg/mocks/requestmock/request.go b/pkg/mocks/requestmock/request.go deleted file mode 100644 index 7e6ca1b1..00000000 --- a/pkg/mocks/requestmock/request.go +++ /dev/null @@ -1,15 +0,0 @@ -package requestmock - -import ( - 
"github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/mock" - "io" -) - -type RequestMock struct { - mock.Mock -} - -func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - return r.Called(method, target, body, opts).Get(0).(*request.Response) -} diff --git a/pkg/mocks/thumbmock/thumb.go b/pkg/mocks/thumbmock/thumb.go deleted file mode 100644 index 553ba50e..00000000 --- a/pkg/mocks/thumbmock/thumb.go +++ /dev/null @@ -1,25 +0,0 @@ -package thumbmock - -import ( - "context" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" - "github.com/stretchr/testify/mock" - "io" -) - -type GeneratorMock struct { - mock.Mock -} - -func (g GeneratorMock) Generate(ctx context.Context, file io.Reader, src string, name string, options map[string]string) (*thumb.Result, error) { - res := g.Called(ctx, file, src, name, options) - return res.Get(0).(*thumb.Result), res.Error(1) -} - -func (g GeneratorMock) Priority() int { - return 0 -} - -func (g GeneratorMock) EnableFlag() string { - return "thumb_vips_enabled" -} diff --git a/pkg/mocks/wopimock/mock.go b/pkg/mocks/wopimock/mock.go deleted file mode 100644 index 0573c047..00000000 --- a/pkg/mocks/wopimock/mock.go +++ /dev/null @@ -1,21 +0,0 @@ -package wopimock - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/stretchr/testify/mock" -) - -type WopiClientMock struct { - mock.Mock -} - -func (w *WopiClientMock) NewSession(user uint, file *model.File, action wopi.ActonType) (*wopi.Session, error) { - args := w.Called(user, file, action) - return args.Get(0).(*wopi.Session), args.Error(1) -} - -func (w *WopiClientMock) AvailableExts() []string { - args := w.Called() - return args.Get(0).([]string) -} diff --git a/pkg/mq/mq.go b/pkg/mq/mq.go deleted file mode 100644 index e7a8a344..00000000 --- a/pkg/mq/mq.go +++ /dev/null @@ -1,160 +0,0 @@ -package mq - -import ( - "encoding/gob" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "strconv" - "sync" - "time" -) - -// Message 消息事件正文 -type Message struct { - // 消息触发者 - TriggeredBy string - - // 事件标识 - Event string - - // 消息正文 - Content interface{} -} - -type CallbackFunc func(Message) - -// MQ 消息队列 -type MQ interface { - rpc.Notifier - - // 发布一个消息 - Publish(string, Message) - - // 订阅一个消息主题 - Subscribe(string, int) <-chan Message - - // 订阅一个消息主题,注册触发回调函数 - SubscribeCallback(string, CallbackFunc) - - // 取消订阅一个消息主题 - Unsubscribe(string, <-chan Message) -} - -var GlobalMQ = NewMQ() - -func NewMQ() MQ { - return &inMemoryMQ{ - topics: make(map[string][]chan Message), - callbacks: make(map[string][]CallbackFunc), - } -} - -func init() { - gob.Register(Message{}) - gob.Register([]rpc.Event{}) -} - -type inMemoryMQ struct { - topics map[string][]chan Message - callbacks map[string][]CallbackFunc - sync.RWMutex -} - -func (i *inMemoryMQ) Publish(topic string, message Message) { - i.RLock() - subscribersChan, okChan := i.topics[topic] - subscribersCallback, okCallback := i.callbacks[topic] - i.RUnlock() - - if okChan { - go func(subscribersChan []chan Message) { - for i := 0; i < len(subscribersChan); i++ { - select { - case subscribersChan[i] <- message: - case <-time.After(time.Millisecond * 500): - } - } - }(subscribersChan) - - } - - if okCallback { - for i := 0; i < len(subscribersCallback); i++ { - go subscribersCallback[i](message) - } - } -} - -func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan 
Message { - ch := make(chan Message, buffer) - i.Lock() - i.topics[topic] = append(i.topics[topic], ch) - i.Unlock() - return ch -} - -func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) { - i.Lock() - i.callbacks[topic] = append(i.callbacks[topic], callbackFunc) - i.Unlock() -} - -func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) { - i.Lock() - defer i.Unlock() - - subscribers, ok := i.topics[topic] - if !ok { - return - } - - var newSubs []chan Message - for _, subscriber := range subscribers { - if subscriber == sub { - continue - } - newSubs = append(newSubs, subscriber) - } - - i.topics[topic] = newSubs -} - -func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) { - for _, event := range events { - i.Publish(event.Gid, Message{ - TriggeredBy: event.Gid, - Event: strconv.FormatInt(int64(status), 10), - Content: events, - }) - } -} - -// OnDownloadStart 下载开始 -func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) { - i.Aria2Notify(events, common.Downloading) -} - -// OnDownloadPause 下载暂停 -func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) { - i.Aria2Notify(events, common.Paused) -} - -// OnDownloadStop 下载停止 -func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) { - i.Aria2Notify(events, common.Canceled) -} - -// OnDownloadComplete 下载完成 -func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) { - i.Aria2Notify(events, common.Complete) -} - -// OnDownloadError 下载出错 -func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) { - i.Aria2Notify(events, common.Error) -} - -// OnBtDownloadComplete BT下载完成 -func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) { - i.Aria2Notify(events, common.Complete) -} diff --git a/pkg/mq/mq_test.go b/pkg/mq/mq_test.go deleted file mode 100644 index 9acdd3f6..00000000 --- a/pkg/mq/mq_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package mq - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/stretchr/testify/assert" - "sync" - "testing" - "time" -) - -func TestPublishAndSubscribe(t *testing.T) { - t.Parallel() - asserts := assert.New(t) - mq := NewMQ() - - // No subscriber - { - asserts.NotPanics(func() { - mq.Publish("No subscriber", Message{}) - }) - } - - // One channel subscriber - { - topic := "One channel subscriber" - msg := Message{TriggeredBy: "Tester"} - notifier := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber - { - topic := "two channel subscriber" - msg := Message{TriggeredBy: "Tester"} - notifier := mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - wg.Done() - msgRecv := <-notifier - asserts.Equal(msg, msgRecv) - }() - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber, one timeout - { - topic := "two channel subscriber, one timeout" - msg := Message{TriggeredBy: "Tester"} - mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber, one unsubscribe - { - topic := "two channel subscriber, one unsubscribe" - msg := Message{TriggeredBy: "Tester"} - mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - notifier := 
mq.Subscribe(topic, 0) - mq.Unsubscribe(topic, notifier) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - - select { - case <-notifier: - t.Error() - default: - } - } -} - -func TestAria2Interface(t *testing.T) { - t.Parallel() - asserts := assert.New(t) - mq := NewMQ() - var ( - OnDownloadStart int - OnDownloadPause int - OnDownloadStop int - OnDownloadComplete int - OnDownloadError int - ) - l := sync.Mutex{} - - mq.SubscribeCallback("TestAria2Interface", func(message Message) { - asserts.Equal("TestAria2Interface", message.TriggeredBy) - l.Lock() - defer l.Unlock() - switch message.Event { - case "1": - OnDownloadStart++ - case "2": - OnDownloadPause++ - case "5": - OnDownloadStop++ - case "4": - OnDownloadComplete++ - case "3": - OnDownloadError++ - } - }) - - mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - - time.Sleep(time.Duration(500) * time.Millisecond) - - asserts.Equal(2, OnDownloadStart) - asserts.Equal(2, OnDownloadPause) - asserts.Equal(2, OnDownloadStop) - asserts.Equal(4, OnDownloadComplete) - asserts.Equal(2, OnDownloadError) -} diff --git a/pkg/queue/metric.go b/pkg/queue/metric.go new file mode 100644 index 00000000..bb3413ad --- /dev/null +++ b/pkg/queue/metric.go @@ -0,0 +1,79 @@ +package queue + +import "sync/atomic" + +// Metric interface +type Metric interface { + IncBusyWorker() + DecBusyWorker() + BusyWorkers() uint64 + SuccessTasks() uint64 + FailureTasks() uint64 + SubmittedTasks() uint64 + IncSuccessTask() + IncFailureTask() + IncSubmittedTask() +} + +var _ Metric = (*metric)(nil) + +type metric struct { + busyWorkers uint64 + successTasks uint64 + failureTasks uint64 + submittedTasks uint64 + suspendingTasks uint64 +} + +// NewMetric for default metric structure +func NewMetric() Metric { + return &metric{} +} + +func (m *metric) IncBusyWorker() { + atomic.AddUint64(&m.busyWorkers, 1) +} + +func (m *metric) DecBusyWorker() { + atomic.AddUint64(&m.busyWorkers, ^uint64(0)) +} + +func (m *metric) BusyWorkers() uint64 { + return atomic.LoadUint64(&m.busyWorkers) +} + +func (m *metric) IncSuccessTask() { + atomic.AddUint64(&m.successTasks, 1) +} + +func (m *metric) IncFailureTask() { + atomic.AddUint64(&m.failureTasks, 1) +} + +func (m *metric) IncSubmittedTask() { + atomic.AddUint64(&m.submittedTasks, 1) +} + +func (m *metric) SuccessTasks() uint64 { + return atomic.LoadUint64(&m.successTasks) +} + +func (m *metric) FailureTasks() uint64 { + return atomic.LoadUint64(&m.failureTasks) +} + +func (m *metric) SubmittedTasks() uint64 { + return atomic.LoadUint64(&m.submittedTasks) +} + +func (m *metric) SuspendingTasks() uint64 { + return atomic.LoadUint64(&m.suspendingTasks) +} + +func (m *metric) IncSuspendingTask() { + atomic.AddUint64(&m.suspendingTasks, 1) +} + +func (m *metric) DecSuspendingTask() { + atomic.AddUint64(&m.suspendingTasks, ^uint64(0)) +} diff --git a/pkg/queue/options.go b/pkg/queue/options.go new file mode 100644 index 00000000..40c6a498 --- /dev/null +++ b/pkg/queue/options.go @@ -0,0 +1,109 
@@ +package queue + +import ( + "runtime" + "time" +) + +// An Option configures a mutex. +type Option interface { + apply(*options) +} + +// OptionFunc is a function that configures a queue. +type OptionFunc func(*options) + +// Apply calls f(option) +func (f OptionFunc) apply(option *options) { + f(option) +} + +type options struct { + maxTaskExecution time.Duration // Maximum execution time for a Task. + retryDelay time.Duration + taskPullInterval time.Duration + backoffFactor float64 + backoffMaxDuration time.Duration + maxRetry int + resumeTaskType []string + workerCount int + name string +} + +func newDefaultOptions() *options { + return &options{ + workerCount: runtime.NumCPU(), + maxTaskExecution: 60 * time.Hour, + backoffFactor: 2, + backoffMaxDuration: 60 * time.Second, + resumeTaskType: []string{}, + taskPullInterval: 1 * time.Second, + name: "default", + } +} + +// WithMaxTaskExecution set maximum execution time for a Task. +func WithMaxTaskExecution(d time.Duration) Option { + return OptionFunc(func(q *options) { + q.maxTaskExecution = d + }) +} + +// WithRetryDelay set retry delay +func WithRetryDelay(d time.Duration) Option { + return OptionFunc(func(q *options) { + q.retryDelay = d + }) +} + +// WithBackoffFactor set backoff factor +func WithBackoffFactor(f float64) Option { + return OptionFunc(func(q *options) { + q.backoffFactor = f + }) +} + +// WithBackoffMaxDuration set backoff max duration +func WithBackoffMaxDuration(d time.Duration) Option { + return OptionFunc(func(q *options) { + q.backoffMaxDuration = d + }) +} + +// WithMaxRetry set max retry +func WithMaxRetry(n int) Option { + return OptionFunc(func(q *options) { + q.maxRetry = n + }) +} + +// WithResumeTaskType set resume Task type +func WithResumeTaskType(types ...string) Option { + return OptionFunc(func(q *options) { + q.resumeTaskType = types + }) +} + +// WithWorkerCount set worker count +func WithWorkerCount(num int) Option { + return OptionFunc(func(q *options) { + if num <= 0 { + num = runtime.NumCPU() + } + q.workerCount = num + }) +} + +// WithName set queue name +func WithName(name string) Option { + return OptionFunc(func(q *options) { + q.name = name + }) +} + +// WithTaskPullInterval set task pull interval +func WithTaskPullInterval(d time.Duration) Option { + return OptionFunc(func(q *options) { + q.taskPullInterval = d + }) +} diff --git a/pkg/queue/queue.go b/pkg/queue/queue.go new file mode 100644 index 00000000..2a28d838 --- /dev/null +++ b/pkg/queue/queue.go @@ -0,0 +1,437 @@ +package queue + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/jpillora/backoff" +) + +type ( + Queue interface { + // Start resume tasks and starts all workers. + Start() + // Shutdown stops all workers. + Shutdown() + // SubmitTask submits a Task to the queue. + QueueTask(ctx context.Context, t Task) error + // BusyWorkers returns the numbers of workers in the running process. + BusyWorkers() int + // BusyWorkers returns the numbers of success tasks. + SuccessTasks() int + // FailureTasks returns the numbers of failure tasks. + FailureTasks() int + // SubmittedTasks returns the numbers of submitted tasks. + SubmittedTasks() int + // SuspendingTasks returns the numbers of suspending tasks. 
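
To make the option surface above concrete, here is a hedged sketch of composing these options when constructing a queue. queue.New appears further down in this patch; the logger, task client, registry and Dep values, and the specific option values, are illustrative assumptions.

```go
import (
	"time"

	"github.com/cloudreve/Cloudreve/v4/inventory"
	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
	"github.com/cloudreve/Cloudreve/v4/pkg/queue"
)

// newMediaMetaQueue is a hypothetical constructor showing how the functional
// options compose; the dependency values are assumed to come from the
// application bootstrap.
func newMediaMetaQueue(l logging.Logger, tc inventory.TaskClient, reg queue.TaskRegistry, dep queue.Dep) queue.Queue {
	return queue.New(l, tc, reg, dep,
		queue.WithName("media_meta"),
		queue.WithWorkerCount(4), // values <= 0 fall back to runtime.NumCPU()
		queue.WithMaxRetry(3),
		queue.WithBackoffMaxDuration(30*time.Second),
		queue.WithResumeTaskType(queue.MediaMetaTaskType),
	)
}
```
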
+ SuspendingTasks() int + } + queue struct { + sync.Mutex + routineGroup *routineGroup + metric *metric + quit chan struct{} + ready chan struct{} + scheduler Scheduler + stopOnce sync.Once + stopFlag int32 + rootCtx context.Context + cancel context.CancelFunc + + // Dependencies + logger logging.Logger + taskClient inventory.TaskClient + dep Dep + registry TaskRegistry + + // Options + *options + } + + Dep interface { + ForkWithLogger(ctx context.Context, l logging.Logger) context.Context + } +) + +var ( + CriticalErr = errors.New("non-retryable error") +) + +func New(l logging.Logger, taskClient inventory.TaskClient, registry TaskRegistry, dep Dep, opts ...Option) Queue { + o := newDefaultOptions() + for _, opt := range opts { + opt.apply(o) + } + + ctx, cancel := context.WithCancel(context.Background()) + + return &queue{ + routineGroup: newRoutineGroup(), + scheduler: NewFifoScheduler(0, l), + quit: make(chan struct{}), + ready: make(chan struct{}, 1), + metric: &metric{}, + options: o, + logger: l, + registry: registry, + taskClient: taskClient, + dep: dep, + rootCtx: ctx, + cancel: cancel, + } +} + +// Start to enable all worker +func (q *queue) Start() { + q.routineGroup.Run(func() { + // Resume tasks in DB + if len(q.options.resumeTaskType) > 0 && q.taskClient != nil { + + ctx := context.TODO() + ctx = context.WithValue(ctx, inventory.LoadTaskUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true) + tasks, err := q.taskClient.GetPendingTasks(ctx, q.resumeTaskType...) + if err != nil { + q.logger.Warning("Failed to get pending tasks from DB for given type %v: %s", q.resumeTaskType, err) + } + + resumed := 0 + for _, t := range tasks { + resumedTask, err := NewTaskFromModel(t) + if err != nil { + q.logger.Warning("Failed to resume task %d: %s", t.ID, err) + continue + } + + if resumedTask.Status() == task.StatusSuspending { + q.metric.IncSuspendingTask() + q.metric.IncSubmittedTask() + } + + if err := q.QueueTask(ctx, resumedTask); err != nil { + q.logger.Warning("Failed to resume task %d: %s", t.ID, err) + } + resumed++ + } + + q.logger.Info("Resumed %d tasks from DB.", resumed) + } + + q.start() + }) + q.logger.Info("Queue %q started with %d workers.", q.name, q.workerCount) +} + +// Shutdown stops all queues. +func (q *queue) Shutdown() { + q.logger.Info("Shutting down queue %q...", q.name) + defer func() { + q.routineGroup.Wait() + }() + + if !atomic.CompareAndSwapInt32(&q.stopFlag, 0, 1) { + return + } + + q.stopOnce.Do(func() { + q.cancel() + if q.metric.BusyWorkers() > 0 { + q.logger.Info("shutdown all tasks in queue %q: %d workers", q.name, q.metric.BusyWorkers()) + } + + if err := q.scheduler.Shutdown(); err != nil { + q.logger.Error("failed to shutdown scheduler in queue %q: %w", q.name, err) + } + close(q.quit) + }) + +} + +// BusyWorkers returns the numbers of workers in the running process. +func (q *queue) BusyWorkers() int { + return int(q.metric.BusyWorkers()) +} + +// BusyWorkers returns the numbers of success tasks. +func (q *queue) SuccessTasks() int { + return int(q.metric.SuccessTasks()) +} + +// BusyWorkers returns the numbers of failure tasks. +func (q *queue) FailureTasks() int { + return int(q.metric.FailureTasks()) +} + +// BusyWorkers returns the numbers of submitted tasks. +func (q *queue) SubmittedTasks() int { + return int(q.metric.SubmittedTasks()) +} + +// SuspendingTasks returns the numbers of suspending tasks. 
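
Continuing the sketch above, the lifecycle is Start → QueueTask → Shutdown, and the accessors expose the counters maintained by metric.go. This fragment is illustrative only; q comes from the hypothetical constructor in the previous example, and ctx, l and t are assumed to exist (t must implement the Task interface defined later in this patch).

```go
q := newMediaMetaQueue(l, tc, reg, dep)
q.Start()          // resumes persisted tasks, then spins up workers
defer q.Shutdown() // stops the scheduler and waits for running workers

if err := q.QueueTask(ctx, t); err != nil {
	l.Error("Failed to queue task: %s", err)
}
l.Info("busy=%d submitted=%d success=%d failure=%d",
	q.BusyWorkers(), q.SubmittedTasks(), q.SuccessTasks(), q.FailureTasks())
```
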
+func (q *queue) SuspendingTasks() int { + return int(q.metric.SuspendingTasks()) +} + +// QueueTask to queue single Task +func (q *queue) QueueTask(ctx context.Context, t Task) error { + if atomic.LoadInt32(&q.stopFlag) == 1 { + return ErrQueueShutdown + } + + if t.Status() != task.StatusSuspending { + q.metric.IncSubmittedTask() + if err := q.transitStatus(ctx, t, task.StatusQueued); err != nil { + return err + } + } + + if err := q.scheduler.Queue(t); err != nil { + return err + } + owner := "" + if t.Owner() != nil { + owner = t.Owner().Email + } + q.logger.Info("New Task with type %q submitted to queue %q by %q", t.Type(), q.name, owner) + if q.registry != nil { + q.registry.Set(t.ID(), t) + } + + return nil +} + +// newContext creates a new context for a new Task iteration. +func (q *queue) newContext(t Task) context.Context { + l := q.logger.CopyWithPrefix(fmt.Sprintf("[Cid: %s TaskID: %d Queue: %s]", t.CorrelationID(), t.ID(), q.name)) + ctx := q.dep.ForkWithLogger(q.rootCtx, l) + ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, t.CorrelationID()) + ctx = context.WithValue(ctx, logging.LoggerCtx{}, l) + ctx = context.WithValue(ctx, inventory.UserCtx{}, t.Owner()) + return ctx +} + +func (q *queue) work(t Task) { + ctx := q.newContext(t) + l := logging.FromContext(ctx) + timeIterationStart := time.Now() + + var err error + // to handle panic cases from inside the worker + // in such case, we start a new goroutine + defer func() { + q.metric.DecBusyWorker() + e := recover() + if e != nil { + l.Error("Panic error in queue %q: %v", q.name, e) + t.OnError(fmt.Errorf("panic error: %v", e), time.Since(timeIterationStart)) + + _ = q.transitStatus(ctx, t, task.StatusError) + } + q.schedule() + }() + + err = q.transitStatus(ctx, t, task.StatusProcessing) + if err != nil { + l.Error("failed to transit task %d to processing: %s", t.ID(), err.Error()) + panic(err) + } + + for { + timeIterationStart = time.Now() + var next task.Status + next, err = q.run(ctx, t) + if err != nil { + t.OnError(err, time.Since(timeIterationStart)) + l.Error("runtime error in queue %q: %s", q.name, err.Error()) + + _ = q.transitStatus(ctx, t, task.StatusError) + break + } + + // iteration completes + t.OnIterationComplete(time.Since(timeIterationStart)) + _ = q.transitStatus(ctx, t, next) + if next != task.StatusProcessing { + break + } + } +} + +func (q *queue) run(ctx context.Context, t Task) (task.Status, error) { + l := logging.FromContext(ctx) + + // create channel with buffer size 1 to avoid goroutine leak + done := make(chan struct { + err error + next task.Status + }, 1) + panicChan := make(chan interface{}, 1) + startTime := time.Now() + ctx, cancel := context.WithTimeout(ctx, q.maxTaskExecution-t.Executed()) + defer func() { + cancel() + }() + + // run the job + go func() { + // handle panic issue + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + l.Debug("Iteration started.") + next, err := t.Do(ctx) + l.Debug("Iteration ended with err=%s", err) + if err != nil && q.maxRetry-t.Retried() > 0 && !errors.Is(err, CriticalErr) && atomic.LoadInt32(&q.stopFlag) != 1 { + // Retry needed + t.OnRetry(err) + b := &backoff.Backoff{ + Max: q.backoffMaxDuration, + Factor: q.backoffFactor, + } + delay := q.retryDelay + if q.retryDelay == 0 { + delay = b.ForAttempt(float64(t.Retried())) + } + + // Resume after to retry + l.Info("Will be retried in %s", delay) + t.OnSuspend(time.Now().Add(delay).Unix()) + err = nil + next = task.StatusSuspending + } + + done <- struct { + err error + 
next task.Status + }{err: err, next: next} + }() + + select { + case p := <-panicChan: + panic(p) + case <-ctx.Done(): // timeout reached + return task.StatusError, ctx.Err() + case <-q.quit: // shutdown service + // cancel job + cancel() + + leftTime := q.maxTaskExecution - t.Executed() - time.Since(startTime) + // wait job + select { + case <-time.After(leftTime): + return task.StatusError, context.DeadlineExceeded + case r := <-done: // job finish + return r.next, r.err + case p := <-panicChan: + panic(p) + } + case r := <-done: // job finish + return r.next, r.err + } +} + +// beforeTaskStart updates Task status from queued to processing +func (q *queue) transitStatus(ctx context.Context, task Task, to task.Status) (err error) { + old := task.Status() + transition, ok := stateTransitions[task.Status()][to] + if !ok { + err = fmt.Errorf("invalid state transition from %s to %s", old, to) + } else { + if innerErr := transition(ctx, task, to, q); innerErr != nil { + err = fmt.Errorf("failed to transit Task status from %s to %s: %w", old, to, innerErr) + } + } + + l := logging.FromContext(ctx) + if err != nil { + l.Error(err.Error()) + } + + l.Info("Task %d status changed from %q to %q.", task.ID(), old, to) + return +} + +// schedule to check worker number +func (q *queue) schedule() { + q.Lock() + defer q.Unlock() + if q.BusyWorkers() >= q.workerCount { + return + } + + select { + case q.ready <- struct{}{}: + default: + } +} + +// start to start all worker +func (q *queue) start() { + tasks := make(chan Task, 1) + + for { + // check worker number + q.schedule() + + select { + // wait worker ready + case <-q.ready: + case <-q.quit: + return + } + + // request Task from queue in background + q.routineGroup.Run(func() { + for { + t, err := q.scheduler.Request() + if t == nil || err != nil { + if err != nil { + select { + case <-q.quit: + if !errors.Is(err, ErrNoTaskInQueue) { + close(tasks) + return + } + case <-time.After(q.taskPullInterval): + // sleep to fetch new Task + } + } + } + if t != nil { + tasks <- t + return + } + + select { + case <-q.quit: + if !errors.Is(err, ErrNoTaskInQueue) { + close(tasks) + return + } + default: + } + } + }) + + t, ok := <-tasks + if !ok { + return + } + + // start new Task + q.metric.IncBusyWorker() + q.routineGroup.Run(func() { + q.work(t) + }) + } +} diff --git a/pkg/queue/registry.go b/pkg/queue/registry.go new file mode 100644 index 00000000..99da4593 --- /dev/null +++ b/pkg/queue/registry.go @@ -0,0 +1,60 @@ +package queue + +import "sync" + +type ( + // TaskRegistry is used in slave node to track in-memory stateful tasks. + TaskRegistry interface { + // NextID returns the next available Task ID. + NextID() int + // Get returns the Task by ID. + Get(id int) (Task, bool) + // Set sets the Task by ID. + Set(id int, t Task) + // Delete deletes the Task by ID. + Delete(id int) + } + + taskRegistry struct { + tasks map[int]Task + current int + mu sync.Mutex + } +) + +// NewTaskRegistry creates a new TaskRegistry. 
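
A short usage sketch for the registry declared above (hypothetical; t is assumed to implement the Task interface from task.go in this patch). Since the registry is the slave node's only record of an in-memory stateful task, callers allocate the ID, store the task, and delete it once finished.

```go
reg := queue.NewTaskRegistry()

id := reg.NextID() // monotonically increasing in-memory ID
reg.Set(id, t)

if tracked, ok := reg.Get(id); ok {
	_ = tracked.Progress(ctx) // e.g. report progress back to the master
}

reg.Delete(id)
```
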
+func NewTaskRegistry() TaskRegistry { + return &taskRegistry{ + tasks: make(map[int]Task), + } +} + +func (r *taskRegistry) NextID() int { + r.mu.Lock() + defer r.mu.Unlock() + + r.current++ + return r.current +} + +func (r *taskRegistry) Get(id int) (Task, bool) { + r.mu.Lock() + defer r.mu.Unlock() + + t, ok := r.tasks[id] + return t, ok +} + +func (r *taskRegistry) Set(id int, t Task) { + r.mu.Lock() + defer r.mu.Unlock() + + r.tasks[id] = t +} + +func (r *taskRegistry) Delete(id int) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.tasks, id) +} diff --git a/pkg/queue/scheduler.go b/pkg/queue/scheduler.go new file mode 100644 index 00000000..a0c47543 --- /dev/null +++ b/pkg/queue/scheduler.go @@ -0,0 +1,124 @@ +package queue + +import ( + "errors" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "sync" + "sync/atomic" + "time" +) + +var ( + // ErrQueueShutdown the queue is released and closed. + ErrQueueShutdown = errors.New("queue has been closed and released") + // ErrMaxCapacity Maximum size limit reached + ErrMaxCapacity = errors.New("golang-queue: maximum size limit reached") + // ErrNoTaskInQueue there is nothing in the queue + ErrNoTaskInQueue = errors.New("golang-queue: no Task in queue") +) + +type ( + Scheduler interface { + // Queue add a new Task into the queue + Queue(task Task) error + // Request get a new Task from the queue + Request() (Task, error) + // Shutdown stop all worker + Shutdown() error + } + fifoScheduler struct { + sync.Mutex + taskQueue taskHeap + capacity int + count int + exit chan struct{} + logger logging.Logger + stopOnce sync.Once + stopFlag int32 + } + taskHeap []Task +) + +// Queue send Task to the buffer channel +func (s *fifoScheduler) Queue(task Task) error { + if atomic.LoadInt32(&s.stopFlag) == 1 { + return ErrQueueShutdown + } + if s.capacity > 0 && s.count >= s.capacity { + return ErrMaxCapacity + } + + s.Lock() + s.taskQueue.Push(task) + s.count++ + s.Unlock() + + return nil +} + +// Request a new Task from channel +func (s *fifoScheduler) Request() (Task, error) { + if atomic.LoadInt32(&s.stopFlag) == 1 { + return nil, ErrQueueShutdown + } + + if s.count == 0 { + return nil, ErrNoTaskInQueue + } + s.Lock() + if s.taskQueue[s.taskQueue.Len()-1].ResumeTime() > time.Now().Unix() { + s.Unlock() + return nil, ErrNoTaskInQueue + } + + data := s.taskQueue.Pop() + s.count-- + s.Unlock() + + return data.(Task), nil +} + +// Shutdown the worker +func (s *fifoScheduler) Shutdown() error { + if !atomic.CompareAndSwapInt32(&s.stopFlag, 0, 1) { + return ErrQueueShutdown + } + + return nil +} + +// NewFifoScheduler for create new Scheduler instance +func NewFifoScheduler(queueSize int, logger logging.Logger) Scheduler { + w := &fifoScheduler{ + taskQueue: make([]Task, 2), + capacity: queueSize, + logger: logger, + } + + return w +} + +// Implement heap.Interface +func (h taskHeap) Len() int { + return len(h) +} + +func (h taskHeap) Less(i, j int) bool { + return h[i].ResumeTime() < h[j].ResumeTime() +} + +func (h taskHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *taskHeap) Push(x any) { + *h = append(*h, x.(Task)) +} + +func (h *taskHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/pkg/queue/task.go b/pkg/queue/task.go new file mode 100644 index 00000000..5318faa6 --- /dev/null +++ b/pkg/queue/task.go @@ -0,0 +1,526 @@ +package queue + +import ( + "context" + "encoding/gob" + "errors" + "fmt" + "sync" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + 
"github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +type ( + Task interface { + Do(ctx context.Context) (task.Status, error) + + // ID returns the Task ID + ID() int + // Type returns the Task type + Type() string + // Status returns the Task status + Status() task.Status + // Owner returns the Task owner + Owner() *ent.User + // State returns the internal Task state + State() string + // ShouldPersist returns true if the Task should be persisted into DB + ShouldPersist() bool + // Persisted returns true if the Task is persisted in DB + Persisted() bool + // Executed returns the duration of the Task execution + Executed() time.Duration + // Retried returns the number of times the Task has been retried + Retried() int + // Error returns the error of the Task + Error() error + // ErrorHistory returns the error history of the Task + ErrorHistory() []error + // Model returns the ent model of the Task + Model() *ent.Task + // CorrelationID returns the correlation ID of the Task + CorrelationID() uuid.UUID + // ResumeTime returns the time when the Task is resumed + ResumeTime() int64 + // ResumeAfter sets the time when the Task should be resumed + ResumeAfter(next time.Duration) + Progress(ctx context.Context) Progresses + // Summarize returns the Task summary for UI display + Summarize(hasher hashid.Encoder) *Summary + // OnSuspend is called when queue decides to suspend the Task + OnSuspend(time int64) + // OnPersisted is called when the Task is persisted or updated in DB + OnPersisted(task *ent.Task) + // OnError is called when the Task encounters an error + OnError(err error, d time.Duration) + // OnRetry is called when the iteration returns error and before retry + OnRetry(err error) + // OnIterationComplete is called when the one iteration is completed + OnIterationComplete(executed time.Duration) + // OnStatusTransition is called when the Task status is changed + OnStatusTransition(newStatus task.Status) + + // Cleanup is called when the Task is done or error. 
+ Cleanup(ctx context.Context) error + + Lock() + Unlock() + } + ResumableTaskFactory func(model *ent.Task) Task + Progress struct { + Total int64 `json:"total"` + Current int64 `json:"current"` + Identifier string `json:"identifier"` + } + Progresses map[string]*Progress + Summary struct { + NodeID int `json:"-"` + Phase string `json:"phase,omitempty"` + Props map[string]any `json:"props,omitempty"` + } + + stateTransition func(ctx context.Context, task Task, newStatus task.Status, q *queue) error +) + +var ( + taskFactories sync.Map +) + +const ( + MediaMetaTaskType = "media_meta" + EntityRecycleRoutineTaskType = "entity_recycle_routine" + ExplicitEntityRecycleTaskType = "explicit_entity_recycle" + UploadSentinelCheckTaskType = "upload_sentinel_check" + CreateArchiveTaskType = "create_archive" + ExtractArchiveTaskType = "extract_archive" + RelocateTaskType = "relocate" + RemoteDownloadTaskType = "remote_download" + + SlaveCreateArchiveTaskType = "slave_create_archive" + SlaveUploadTaskType = "slave_upload" + SlaveExtractArchiveType = "slave_extract_archive" +) + +func init() { + gob.Register(Progresses{}) +} + +// RegisterResumableTaskFactory registers a resumable Task factory +func RegisterResumableTaskFactory(taskType string, factory ResumableTaskFactory) { + taskFactories.Store(taskType, factory) +} + +// NewTaskFromModel creates a Task from ent.Task model +func NewTaskFromModel(model *ent.Task) (Task, error) { + if factory, ok := taskFactories.Load(model.Type); ok { + return factory.(ResumableTaskFactory)(model), nil + } + + return nil, fmt.Errorf("unknown Task type: %s", model.Type) +} + +// InMemoryTask implements part Task interface using in-memory data. +type InMemoryTask struct { + *DBTask +} + +func (i *InMemoryTask) ShouldPersist() bool { + return false +} + +func (t *InMemoryTask) OnStatusTransition(newStatus task.Status) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + t.Task.Status = newStatus + } +} + +// DBTask implements Task interface related to DB schema +type DBTask struct { + DirectOwner *ent.User + Task *ent.Task + + mu sync.Mutex +} + +func (t *DBTask) ID() int { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.ID + } + return 0 +} + +func (t *DBTask) Status() task.Status { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.Status + } + return "" +} + +func (t *DBTask) Type() string { + t.mu.Lock() + defer t.mu.Unlock() + return t.Task.Type +} + +func (t *DBTask) Owner() *ent.User { + t.mu.Lock() + defer t.mu.Unlock() + + if t.DirectOwner != nil { + return t.DirectOwner + } + if t.Task != nil { + return t.Task.Edges.User + } + return nil +} + +func (t *DBTask) State() string { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.PrivateState + } + return "" +} + +func (t *DBTask) Persisted() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.Task != nil && t.Task.ID != 0 +} + +func (t *DBTask) Executed() time.Duration { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.PublicState.ExecutedDuration + } + return 0 +} + +func (t *DBTask) Retried() int { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.PublicState.RetryCount + } + return 0 +} + +func (t *DBTask) Error() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil && t.Task.PublicState.Error != "" { + return errors.New(t.Task.PublicState.Error) + } + + return nil +} + +func (t *DBTask) ErrorHistory() []error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task 
!= nil { + return lo.Map(t.Task.PublicState.ErrorHistory, func(err string, index int) error { + return errors.New(err) + }) + } + + return nil +} + +func (t *DBTask) Model() *ent.Task { + t.mu.Lock() + defer t.mu.Unlock() + return t.Task +} + +func (t *DBTask) Cleanup(ctx context.Context) error { + return nil +} + +func (t *DBTask) CorrelationID() uuid.UUID { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.CorrelationID + } + return uuid.Nil +} + +func (t *DBTask) ShouldPersist() bool { + return true +} + +func (t *DBTask) OnPersisted(task *ent.Task) { + t.mu.Lock() + defer t.mu.Unlock() + + t.Task = task +} + +func (t *DBTask) OnError(err error, d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + t.Task.PublicState.Error = err.Error() + t.Task.PublicState.ExecutedDuration += d + } +} + +func (t *DBTask) OnRetry(err error) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + if t.Task.PublicState.ErrorHistory == nil { + t.Task.PublicState.ErrorHistory = make([]string, 0) + } + + t.Task.PublicState.ErrorHistory = append(t.Task.PublicState.ErrorHistory, err.Error()) + t.Task.PublicState.RetryCount++ + } +} + +func (t *DBTask) OnIterationComplete(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + t.Task.PublicState.ExecutedDuration += d + } +} + +func (t *DBTask) ResumeTime() int64 { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + return t.Task.PublicState.ResumeTime + } + return 0 +} + +func (t *DBTask) OnSuspend(time int64) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + t.Task.PublicState.ResumeTime = time + } +} + +func (t *DBTask) Progress(ctx context.Context) Progresses { + return nil +} + +func (t *DBTask) OnStatusTransition(newStatus task.Status) { + // Nop +} + +func (t *DBTask) Lock() { + t.mu.Lock() +} + +func (t *DBTask) Unlock() { + t.mu.Unlock() +} + +func (t *DBTask) Summarize(hasher hashid.Encoder) *Summary { + return &Summary{} +} + +func (t *DBTask) ResumeAfter(next time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.Task != nil { + t.Task.PublicState.ResumeTime = time.Now().Add(next).Unix() + } +} + +var stateTransitions map[task.Status]map[task.Status]stateTransition + +func init() { + stateTransitions = map[task.Status]map[task.Status]stateTransition{ + "": { + task.StatusQueued: persistTask, + }, + task.StatusQueued: { + task.StatusProcessing: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + if err := persistTask(ctx, task, newStatus, q); err != nil { + return err + } + return nil + }, + task.StatusQueued: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + return nil + }, + task.StatusError: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.metric.IncFailureTask() + return persistTask(ctx, task, newStatus, q) + }, + }, + task.StatusProcessing: { + task.StatusQueued: persistTask, + task.StatusCompleted: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.logger.Info("Execution completed in %s with %d retries, clean up...", task.Executed(), task.Retried()) + q.metric.IncSuccessTask() + + if err := task.Cleanup(ctx); err != nil { + q.logger.Error("Task cleanup failed: %s", err.Error()) + } + + if q.registry != nil { + q.registry.Delete(task.ID()) + } + + if err := persistTask(ctx, task, newStatus, q); err != nil { + return err + } + return nil + }, + task.StatusError: func(ctx context.Context, task Task, newStatus 
task.Status, q *queue) error { + q.logger.Error("Execution failed with error in %s with %d retries, clean up...", task.Executed(), task.Retried()) + q.metric.IncFailureTask() + + if err := task.Cleanup(ctx); err != nil { + q.logger.Error("Task cleanup failed: %s", err.Error()) + } + + if q.registry != nil { + q.registry.Delete(task.ID()) + } + + if err := persistTask(ctx, task, newStatus, q); err != nil { + return err + } + + return nil + }, + task.StatusCanceled: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.logger.Info("Execution canceled, clean up...", task.Executed(), task.Retried()) + q.metric.IncFailureTask() + + if err := task.Cleanup(ctx); err != nil { + q.logger.Error("Task cleanup failed: %s", err.Error()) + } + + if q.registry != nil { + q.registry.Delete(task.ID()) + } + + if err := persistTask(ctx, task, newStatus, q); err != nil { + return err + } + + return nil + }, + task.StatusProcessing: persistTask, + task.StatusSuspending: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.metric.IncSuspendingTask() + if err := persistTask(ctx, task, newStatus, q); err != nil { + return err + } + q.logger.Info("Task %d suspended, resume time: %d", task.ID(), task.ResumeTime()) + return q.QueueTask(ctx, task) + }, + }, + task.StatusSuspending: { + task.StatusProcessing: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.metric.DecSuspendingTask() + return persistTask(ctx, task, newStatus, q) + }, + task.StatusError: func(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + q.metric.IncFailureTask() + return persistTask(ctx, task, newStatus, q) + }, + }, + } + +} + +func persistTask(ctx context.Context, task Task, newState task.Status, q *queue) error { + // Persist Task into inventory + if task.ShouldPersist() { + if err := saveTaskToInventory(ctx, task, newState, q); err != nil { + return err + } + } else { + task.OnStatusTransition(newState) + } + + return nil +} + +func saveTaskToInventory(ctx context.Context, task Task, newStatus task.Status, q *queue) error { + var ( + errStr string + errHistory []string + ) + if err := task.Error(); err != nil { + errStr = err.Error() + } + + errHistory = lo.Map(task.ErrorHistory(), func(err error, index int) string { + return err.Error() + }) + + args := &inventory.TaskArgs{ + Status: newStatus, + Type: task.Type(), + PublicState: &types.TaskPublicState{ + RetryCount: task.Retried(), + ExecutedDuration: task.Executed(), + ErrorHistory: errHistory, + Error: errStr, + ResumeTime: task.ResumeTime(), + }, + PrivateState: task.State(), + OwnerID: task.Owner().ID, + CorrelationID: logging.CorrelationID(ctx), + } + + var ( + res *ent.Task + err error + ) + + if !task.Persisted() { + res, err = q.taskClient.New(ctx, args) + } else { + res, err = q.taskClient.Update(ctx, task.Model(), args) + } + if err != nil { + return fmt.Errorf("failed to persist Task into DB: %w", err) + } + + task.OnPersisted(res) + return nil +} diff --git a/pkg/queue/thread.go b/pkg/queue/thread.go new file mode 100644 index 00000000..473c351e --- /dev/null +++ b/pkg/queue/thread.go @@ -0,0 +1,24 @@ +package queue + +import "sync" + +type routineGroup struct { + waitGroup sync.WaitGroup +} + +func newRoutineGroup() *routineGroup { + return new(routineGroup) +} + +func (g *routineGroup) Run(fn func()) { + g.waitGroup.Add(1) + + go func() { + defer g.waitGroup.Done() + fn() + }() +} + +func (g *routineGroup) Wait() { + g.waitGroup.Wait() +} diff --git 
a/pkg/recaptcha/recaptcha.go b/pkg/recaptcha/recaptcha.go index 75354bde..e3608980 100644 --- a/pkg/recaptcha/recaptcha.go +++ b/pkg/recaptcha/recaptcha.go @@ -67,7 +67,8 @@ type ReCAPTCHA struct { } // NewReCAPTCHA new ReCAPTCHA instance if version is set to V2 uses recatpcha v2 API, get your secret from https://www.google.com/recaptcha/admin -// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3 +// +// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3 func NewReCAPTCHA(ReCAPTCHASecret string, version VERSION, timeout time.Duration) (ReCAPTCHA, error) { if ReCAPTCHASecret == "" { return ReCAPTCHA{}, fmt.Errorf("recaptcha secret cannot be blank") diff --git a/pkg/request/options.go b/pkg/request/options.go index 63bc8dd8..e5b5d844 100644 --- a/pkg/request/options.go +++ b/pkg/request/options.go @@ -2,7 +2,8 @@ package request import ( "context" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" "net/http" "net/url" "strings" @@ -15,18 +16,24 @@ type Option interface { } type options struct { - timeout time.Duration - header http.Header - sign auth.Auth - signTTL int64 - ctx context.Context - contentLength int64 - masterMeta bool - endpoint *url.URL - slaveNodeID string - tpsLimiterToken string - tps float64 - tpsBurst int + timeout time.Duration + header http.Header + sign auth.Auth + signTTL int64 + ctx context.Context + contentLength int64 + masterMeta bool + siteID string + siteURL string + endpoint *url.URL + slaveNodeID int + tpsLimiterToken string + tps float64 + tpsBurst int + logger logging.Logger + withCorrelationID bool + cookieJar http.CookieJar + transport *http.Transport } type optionFunc func(*options) @@ -38,7 +45,7 @@ func (f optionFunc) apply(o *options) { func newDefaultOption() *options { return &options{ header: http.Header{}, - timeout: time.Duration(30) * time.Second, + timeout: 0, contentLength: -1, ctx: context.Background(), } @@ -50,6 +57,13 @@ func (o *options) clone() options { return newOptions } +// WithTransport 设置请求Transport +func WithTransport(transport *http.Transport) Option { + return optionFunc(func(o *options) { + o.transport = transport + }) +} + // WithTimeout 设置请求超时 func WithTimeout(t time.Duration) Option { return optionFunc(func(o *options) { @@ -68,7 +82,9 @@ func WithContext(c context.Context) Option { func WithCredential(instance auth.Auth, ttl int64) Option { return optionFunc(func(o *options) { o.sign = instance - o.signTTL = ttl + if ttl > 0 { + o.signTTL = ttl + } }) } @@ -99,14 +115,16 @@ func WithContentLength(s int64) Option { } // WithMasterMeta 请求时携带主机信息 -func WithMasterMeta() Option { +func WithMasterMeta(siteID string, siteURL string) Option { return optionFunc(func(o *options) { o.masterMeta = true + o.siteID = siteID + o.siteURL = siteURL }) } -// WithSlaveMeta 请求时携带从机信息 -func WithSlaveMeta(s string) Option { +// WithSlaveMeta set slave node ID in master's request header +func WithSlaveMeta(s int) Option { return optionFunc(func(o *options) { o.slaveNodeID = s }) @@ -135,3 +153,24 @@ func WithTPSLimit(token string, tps float64, burst int) Option { o.tpsBurst = burst }) } + +// WithLogger set logger for logging requests +func WithLogger(logger logging.Logger) Option { + return optionFunc(func(o *options) { + o.logger = logger + }) +} + +// WithCorrelationID set correlation ID for logging requests +func WithCorrelationID() Option { + return 
optionFunc(func(o *options) { + o.withCorrelationID = true + }) +} + +// WithCookieJar set cookie jar for request +func WithCookieJar(jar http.CookieJar) Option { + return optionFunc(func(o *options) { + o.cookieJar = jar + }) +} diff --git a/pkg/request/request.go b/pkg/request/request.go index 29470852..c5a8301b 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -1,34 +1,50 @@ package request import ( + "context" "encoding/json" "errors" "fmt" "io" - "io/ioutil" "net/http" "net/url" + "strconv" "strings" "sync" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/samber/lo" ) // GeneralClient 通用 HTTP Client -var GeneralClient Client = NewClient() +var GeneralClient Client = NewClientDeprecated() + +const ( + CorrelationHeader = constants.CrHeaderPrefix + "Correlation-Id" + SiteURLHeader = constants.CrHeaderPrefix + "Site-Url" + SiteVersionHeader = constants.CrHeaderPrefix + "Version" + SiteIDHeader = constants.CrHeaderPrefix + "Site-Id" + SlaveNodeIDHeader = constants.CrHeaderPrefix + "Node-Id" + LocalIP = "localhost" +) // Response 请求的响应或错误信息 type Response struct { Err error Response *http.Response + l logging.Logger } // Client 请求客户端 type Client interface { + // Apply applies the given options to this client. + Apply(opts ...Option) + // Request send a HTTP request Request(method, target string, body io.Reader, opts ...Option) *Response } @@ -37,9 +53,26 @@ type HTTPClient struct { mu sync.Mutex options *options tpsLimiter TPSLimiter + l logging.Logger + config conf.ConfigProvider +} + +func NewClient(config conf.ConfigProvider, opts ...Option) Client { + client := &HTTPClient{ + options: newDefaultOption(), + tpsLimiter: globalTPSLimiter, + config: config, + } + + for _, o := range opts { + o.apply(client.options) + } + + return client } -func NewClient(opts ...Option) Client { +// Deprecated +func NewClientDeprecated(opts ...Option) Client { client := &HTTPClient{ options: newDefaultOption(), tpsLimiter: globalTPSLimiter, @@ -52,6 +85,12 @@ func NewClient(opts ...Option) Client { return client } +func (c *HTTPClient) Apply(opts ...Option) { + for _, o := range opts { + o.apply(c.options) + } +} + // Request 发送HTTP请求 func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 @@ -63,7 +102,14 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti } // 创建请求客户端 - client := &http.Client{Timeout: options.timeout} + client := &http.Client{ + Timeout: options.timeout, + Jar: options.cookieJar, + } + + if options.transport != nil { + client.Transport = options.transport + } // size为0时将body设为nil if options.contentLength == 0 { @@ -86,6 +132,7 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti req *http.Request err error ) + start := time.Now() if options.ctx != nil { req, err = http.NewRequestWithContext(options.ctx, method, target, body) } else { @@ -102,14 +149,21 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti } } - if options.masterMeta && conf.SystemConfig.Mode 
== "master" { - req.Header.Add(auth.CrHeaderPrefix+"Site-Url", model.GetSiteURL().String()) - req.Header.Add(auth.CrHeaderPrefix+"Site-Id", model.GetSettingByName("siteID")) - req.Header.Add(auth.CrHeaderPrefix+"Cloudreve-Version", conf.BackendVersion) + req.Header.Set("User-Agent", "Cloudreve/"+constants.BackendVersion) + + if options.ctx != nil && options.withCorrelationID { + req.Header.Add(CorrelationHeader, logging.CorrelationID(options.ctx).String()) + } + + mode := c.config.System().Mode + if options.masterMeta && mode == conf.MasterMode { + req.Header.Add(SiteURLHeader, options.siteURL) + req.Header.Add(SiteIDHeader, options.siteID) + req.Header.Add(SiteVersionHeader, constants.BackendVersion) } - if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" { - req.Header.Add(auth.CrHeaderPrefix+"Node-Id", options.slaveNodeID) + if options.slaveNodeID > 0 { + req.Header.Add(SlaveNodeIDHeader, strconv.Itoa(options.slaveNodeID)) } if options.contentLength != -1 { @@ -118,11 +172,16 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti // 签名请求 if options.sign != nil { + ctx := options.ctx + if options.ctx == nil { + ctx = context.Background() + } + expire := time.Now().Add(time.Second * time.Duration(options.signTTL)) switch method { case "PUT", "POST", "PATCH": - auth.SignRequest(options.sign, req, options.signTTL) + auth.SignRequest(ctx, options.sign, req, &expire) default: - if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil { + if resURL, err := auth.SignURI(ctx, options.sign, req.URL.String(), &expire); err == nil { req.URL = resURL } } @@ -134,11 +193,32 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti // 发送请求 resp, err := client.Do(req) + + // Logging request + if options.logger != nil { + statusCode := 0 + errStr := "" + if resp != nil { + statusCode = resp.StatusCode + } + + if err != nil { + errStr = err.Error() + } + + logging.Request(options.logger, false, statusCode, req.Method, LocalIP, req.URL.String(), errStr, start) + } + + // Apply cookies + if resp != nil && resp.Cookies() != nil && options.cookieJar != nil { + options.cookieJar.SetCookies(req.URL, resp.Cookies()) + } + if err != nil { return &Response{Err: err} } - return &Response{Err: nil, Response: resp} + return &Response{Err: nil, Response: resp, l: options.logger} } // GetResponse 检查响应并获取响应正文 @@ -146,21 +226,33 @@ func (resp *Response) GetResponse() (string, error) { if resp.Err != nil { return "", resp.Err } - respBody, err := ioutil.ReadAll(resp.Response.Body) + respBody, err := io.ReadAll(resp.Response.Body) _ = resp.Response.Body.Close() return string(respBody), err } +// GetResponseIgnoreErr 获取响应正文 +func (resp *Response) GetResponseIgnoreErr() (string, error) { + if resp.Response == nil { + return "", resp.Err + } + + respBody, _ := io.ReadAll(resp.Response.Body) + _ = resp.Response.Body.Close() + + return string(respBody), resp.Err +} + // CheckHTTPResponse 检查请求响应HTTP状态码 -func (resp *Response) CheckHTTPResponse(status int) *Response { +func (resp *Response) CheckHTTPResponse(status ...int) *Response { if resp.Err != nil { return resp } // 检查HTTP状态码 - if resp.Response.StatusCode != status { - resp.Err = fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode) + if !lo.Contains(status, resp.Response.StatusCode) { + resp.Err = fmt.Errorf("Remote returns unexpected status code: %d", resp.Response.StatusCode) } return resp } @@ -179,7 +271,10 @@ func (resp *Response) DecodeResponse() 
(*serializer.Response, error) { var res serializer.Response err = json.Unmarshal([]byte(respString), &res) if err != nil { - util.Log().Debug("Failed to parse response: %s", string(respString)) + if resp.l != nil { + resp.l.Debug("Failed to parse response: %s", respString) + } + return nil, err } return &res, nil @@ -254,10 +349,3 @@ func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) { return 0, errors.New("not implemented") } - -// BlackHole 将客户端发来的数据放入黑洞 -func BlackHole(r io.Reader) { - if !model.IsTrueVal(model.GetSettingByName("reset_after_upload_failed")) { - io.Copy(ioutil.Discard, r) - } -} diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go index e54831e2..4e062df9 100644 --- a/pkg/request/request_test.go +++ b/pkg/request/request_test.go @@ -3,17 +3,16 @@ package request import ( "context" "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" "io" "io/ioutil" "net/http" "strings" "testing" "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" ) type ClientMock struct { @@ -55,7 +54,7 @@ func TestWithContext(t *testing.T) { func TestHTTPClient_Request(t *testing.T) { asserts := assert.New(t) - client := NewClient(WithSlaveMeta("test")) + client := NewClientDeprecated(WithSlaveMeta("test")) // 正常 { @@ -63,8 +62,6 @@ func TestHTTPClient_Request(t *testing.T) { "POST", "/test", strings.NewReader(""), - WithContentLength(0), - WithEndpoint("http://cloudreveisnotexist.com"), WithTimeout(time.Duration(1)*time.Microsecond), WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), WithoutHeader([]string{"origin", "origin"}), @@ -79,11 +76,11 @@ func TestHTTPClient_Request(t *testing.T) { "GET", "http://cloudreveisnotexist.com", strings.NewReader(""), + WithContentLength(0), + WithEndpoint("http://cloudreveisnotexist.com"), WithTimeout(time.Duration(1)*time.Microsecond), WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), WithContext(context.Background()), - WithoutHeader([]string{"s s", "s s"}), - WithMasterMeta(), ) asserts.Error(resp.Err) asserts.Nil(resp.Response) @@ -241,7 +238,7 @@ func TestBlackHole(t *testing.T) { func TestHTTPClient_TPSLimit(t *testing.T) { a := assert.New(t) - client := NewClient() + client := NewClientDeprecated() finished := make(chan struct{}) go func() { diff --git a/pkg/request/utils.go b/pkg/request/utils.go new file mode 100644 index 00000000..b87303c2 --- /dev/null +++ b/pkg/request/utils.go @@ -0,0 +1,62 @@ +package request + +import ( + "io" + "net/http" + "strconv" +) + +var contentLengthHeaders = []string{ + "Content-Length", + "X-Expected-Entity-Length", // DavFS on MacOS +} + +// BlackHole 将客户端发来的数据放入黑洞 +func BlackHole(r io.Reader) { + io.Copy(io.Discard, r) +} + +// SniffContentLength tries to get the content length from the request. It also returns +// a reader that will limit to the sniffed content length. 
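As a usage illustration for the helper being introduced here: a minimal sketch of how `SniffContentLength` and its `LimitReaderCloser` might be consumed in an upload handler. Only the function and interface shown in this diff are taken as given; the handler name, route, and response text are assumptions for illustration.

```go
package main

import (
	"fmt"
	"io"
	"net/http"

	"github.com/cloudreve/Cloudreve/v4/pkg/request"
)

// uploadHandler is a hypothetical handler: it resolves the expected body size
// from Content-Length (or the DavFS fallback header), then reads through the
// returned reader, which is capped at the sniffed size and tracks bytes read
// via Count().
func uploadHandler(w http.ResponseWriter, r *http.Request) {
	body, expected, err := request.SniffContentLength(r)
	if err != nil {
		http.Error(w, "invalid content length header", http.StatusBadRequest)
		return
	}
	defer body.Close()

	written, err := io.Copy(io.Discard, body)
	if err != nil {
		http.Error(w, "read failed", http.StatusInternalServerError)
		return
	}

	fmt.Fprintf(w, "expected %d bytes, read %d (Count=%d)\n", expected, written, body.Count())
}

func main() {
	http.HandleFunc("/upload", uploadHandler)
	_ = http.ListenAndServe(":8080", nil)
}
```

When neither header is present, the diff shows the helper returning a reader limited to zero bytes, so callers that accept unknown lengths would need to treat the zero case separately.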
+func SniffContentLength(r *http.Request) (LimitReaderCloser, int64, error) { + for _, header := range contentLengthHeaders { + if length := r.Header.Get(header); length != "" { + res, err := strconv.ParseInt(length, 10, 64) + if err != nil { + return nil, 0, err + } + + return newLimitReaderCloser(r.Body, res), res, nil + } + } + return newLimitReaderCloser(r.Body, 0), 0, nil +} + +type LimitReaderCloser interface { + io.Reader + io.Closer + Count() int64 +} + +type limitReaderCloser struct { + io.Reader + io.Closer + read int64 +} + +func newLimitReaderCloser(r io.ReadCloser, limit int64) LimitReaderCloser { + return &limitReaderCloser{ + Reader: io.LimitReader(r, limit), + Closer: r, + } +} + +func (l *limitReaderCloser) Read(p []byte) (n int, err error) { + n, err = l.Reader.Read(p) + l.read += int64(n) + return n, err +} + +func (l *limitReaderCloser) Count() int64 { + return l.read +} diff --git a/pkg/serializer/aria2.go b/pkg/serializer/aria2.go deleted file mode 100644 index 890b2b9e..00000000 --- a/pkg/serializer/aria2.go +++ /dev/null @@ -1,117 +0,0 @@ -package serializer - -import ( - "path" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" -) - -// DownloadListResponse 下载列表响应条目 -type DownloadListResponse struct { - UpdateTime time.Time `json:"update"` - UpdateInterval int `json:"interval"` - Name string `json:"name"` - Status int `json:"status"` - Dst string `json:"dst"` - Total uint64 `json:"total"` - Downloaded uint64 `json:"downloaded"` - Speed int `json:"speed"` - Info rpc.StatusInfo `json:"info"` - NodeName string `json:"node"` -} - -// FinishedListResponse 已完成任务条目 -type FinishedListResponse struct { - Name string `json:"name"` - GID string `json:"gid"` - Status int `json:"status"` - Dst string `json:"dst"` - Error string `json:"error"` - Total uint64 `json:"total"` - Files []rpc.FileInfo `json:"files"` - TaskStatus int `json:"task_status"` - TaskError string `json:"task_error"` - CreateTime time.Time `json:"create"` - UpdateTime time.Time `json:"update"` - NodeName string `json:"node"` -} - -// BuildFinishedListResponse 构建已完成任务条目 -func BuildFinishedListResponse(tasks []model.Download) Response { - resp := make([]FinishedListResponse, 0, len(tasks)) - - for i := 0; i < len(tasks); i++ { - fileName := tasks[i].StatusInfo.BitTorrent.Info.Name - if len(tasks[i].StatusInfo.Files) == 1 { - fileName = path.Base(tasks[i].StatusInfo.Files[0].Path) - } - - // 过滤敏感信息 - for i2 := 0; i2 < len(tasks[i].StatusInfo.Files); i2++ { - tasks[i].StatusInfo.Files[i2].Path = path.Base(tasks[i].StatusInfo.Files[i2].Path) - } - - download := FinishedListResponse{ - Name: fileName, - GID: tasks[i].GID, - Status: tasks[i].Status, - Error: tasks[i].Error, - Dst: tasks[i].Dst, - Total: tasks[i].TotalSize, - Files: tasks[i].StatusInfo.Files, - TaskStatus: -1, - UpdateTime: tasks[i].UpdatedAt, - CreateTime: tasks[i].CreatedAt, - NodeName: tasks[i].NodeName, - } - - if tasks[i].Task != nil { - download.TaskError = tasks[i].Task.Error - download.TaskStatus = tasks[i].Task.Status - } - - resp = append(resp, download) - } - - return Response{Data: resp} -} - -// BuildDownloadingResponse 构建正在下载的列表响应 -func BuildDownloadingResponse(tasks []model.Download, intervals map[uint]int) Response { - resp := make([]DownloadListResponse, 0, len(tasks)) - - for i := 0; i < len(tasks); i++ { - fileName := "" - if len(tasks[i].StatusInfo.Files) > 0 { - fileName = path.Base(tasks[i].StatusInfo.Files[0].Path) - } - - // 过滤敏感信息 - tasks[i].StatusInfo.Dir = "" - 
for i2 := 0; i2 < len(tasks[i].StatusInfo.Files); i2++ { - tasks[i].StatusInfo.Files[i2].Path = path.Base(tasks[i].StatusInfo.Files[i2].Path) - } - - interval := 10 - if actualInterval, ok := intervals[tasks[i].ID]; ok { - interval = actualInterval - } - - resp = append(resp, DownloadListResponse{ - UpdateTime: tasks[i].UpdatedAt, - UpdateInterval: interval, - Name: fileName, - Status: tasks[i].Status, - Dst: tasks[i].Dst, - Total: tasks[i].TotalSize, - Downloaded: tasks[i].DownloadedSize, - Speed: tasks[i].Speed, - Info: tasks[i].StatusInfo, - NodeName: tasks[i].NodeName, - }) - } - - return Response{Data: resp} -} diff --git a/pkg/serializer/aria2_test.go b/pkg/serializer/aria2_test.go deleted file mode 100644 index 1f3ca614..00000000 --- a/pkg/serializer/aria2_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestBuildFinishedListResponse(t *testing.T) { - asserts := assert.New(t) - tasks := []model.Download{ - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name.txt", - }, - }, - }, - Task: &model.Task{ - Model: gorm.Model{}, - Error: "error", - }, - }, - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name1.txt", - }, - { - Path: "/file/name2.txt", - }, - }, - }, - }, - } - tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" - res := BuildFinishedListResponse(tasks).Data.([]FinishedListResponse) - asserts.Len(res, 2) - asserts.Equal("name.txt", res[1].Name) - asserts.Equal("name.txt", res[0].Name) - asserts.Equal("name.txt", res[0].Files[0].Path) - asserts.Equal("name1.txt", res[1].Files[0].Path) - asserts.Equal("name2.txt", res[1].Files[1].Path) - asserts.EqualValues(0, res[0].TaskStatus) - asserts.Equal("error", res[0].TaskError) -} - -func TestBuildDownloadingResponse(t *testing.T) { - asserts := assert.New(t) - cache.Set("setting_aria2_interval", "10", 0) - tasks := []model.Download{ - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name.txt", - }, - }, - }, - Task: &model.Task{ - Model: gorm.Model{}, - Error: "error", - }, - }, - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name1.txt", - }, - { - Path: "/file/name2.txt", - }, - }, - }, - }, - } - tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" - tasks[1].ID = 1 - - res := BuildDownloadingResponse(tasks, map[uint]int{1: 5}).Data.([]DownloadListResponse) - asserts.Len(res, 2) - asserts.Equal("name1.txt", res[1].Name) - asserts.Equal(5, res[1].UpdateInterval) - asserts.Equal("name.txt", res[0].Name) - asserts.Equal("name.txt", res[0].Info.Files[0].Path) - asserts.Equal("name1.txt", res[1].Info.Files[0].Path) - asserts.Equal("name2.txt", res[1].Info.Files[1].Path) -} diff --git a/pkg/serializer/auth_test.go b/pkg/serializer/auth_test.go deleted file mode 100644 index 96b6b9b1..00000000 --- a/pkg/serializer/auth_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package serializer - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewRequestSignString(t *testing.T) { - asserts := assert.New(t) - - sign := NewRequestSignString("1", "2", "3") - asserts.NotEmpty(sign) -} diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 326c0d87..f70a7adf 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -1,8 +1,15 @@ 
package serializer import ( + "context" "errors" + "fmt" + "strings" + + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) // AppError 应用错误,实现了error接口 @@ -32,15 +39,33 @@ func NewErrorFromResponse(resp *Response) AppError { // WithError 将应用error携带标准库中的error func (err *AppError) WithError(raw error) AppError { - err.RawError = raw - return *err + return AppError{ + Code: err.Code, + Msg: err.Msg, + RawError: raw, + } } // Error 返回业务代码确定的可读错误信息 func (err AppError) Error() string { + if err.RawError != nil { + return fmt.Sprintf("%s: %s", err.Msg, err.RawError.Error()) + } return err.Msg } +func (err AppError) ErrCode() int { + var inheritedErr AppError + if errors.As(err.RawError, &inheritedErr) { + return inheritedErr.ErrCode() + } + return err.Code +} + +func (err AppError) Unwrap() error { + return err.RawError +} + // 三位数错误编码为复用http原本含义 // 五位数错误编码为应用自定义错误 // 五开头的五位数错误编码为服务器端错误,比如数据库操作失败 @@ -78,6 +103,8 @@ const ( CodeInvalidChunkIndex = 40012 // CodeInvalidContentLength 无效的正文长度 CodeInvalidContentLength = 40013 + // CodePhoneRequired 未绑定手机 + CodePhoneRequired = 40010 // CodeBatchSourceSize 超出批量获取外链限制 CodeBatchSourceSize = 40014 // CodeBatchAria2Size 超出最大 Aria2 任务数量限制 @@ -112,6 +139,8 @@ const ( CodeInvalidTempLink = 40029 // CodeTempLinkExpired 临时链接过期 CodeTempLinkExpired = 40030 + // CodeEmailProviderBaned 邮箱后缀被禁用 + CodeEmailProviderBaned = 40031 // CodeEmailExisted 邮箱已被使用 CodeEmailExisted = 40032 // CodeEmailSent 邮箱已重新发送 @@ -180,18 +209,50 @@ const ( CodeGroupInvalid = 40064 // 兑换码无效 CodeInvalidGiftCode = 40065 - // 已绑定了QQ账号 - CodeQQBindConflict = 40066 - // QQ账号已被绑定其他账号 - CodeQQBindOtherAccount = 40067 - // QQ 未绑定对应账号 - CodeQQNotLinked = 40068 + // 已绑定了对应账号 + CodeOpenIDBindConflict = 40066 + // 对应账号已被绑定其他账号 + CodeOpenIDBindOtherAccount = 40067 + // 未绑定对应账号 + CodeOpenIDNotLinked = 40068 // 密码不正确 CodeIncorrectPassword = 40069 // 分享无法预览 CodeDisabledSharePreview = 40070 // 签名无效 CodeInvalidSign = 40071 + // 管理员无法购买用户组 + CodeFulfillAdminGroup = 40072 + // Lock confliced + CodeLockConflict = 40073 + // Too many uris + CodeTooManyUris = 40074 + // Lock token expired + CodeLockExpired = 40075 + // Current updated version is stale + CodeStaleVersion = 40076 + // CodeEntityNotExist Entity not exist + CodeEntityNotExist = 40077 + // CodeFileDeleted File is deleted in recycle bin + CodeFileDeleted = 40078 + // CodeFileCountLimitedReached file count limited reached + CodeFileCountLimitedReached = 40079 + // CodeInvalidPassword invalid password + CodeInvalidPassword = 40080 + // CodeBatchOperationNotFullyCompleted batch operation not fully completed + CodeBatchOperationNotFullyCompleted = 40081 + // CodeOwnerOnly owner operation only + CodeOwnerOnly = 40082 + // CodePurchaseRequired purchase required + CodePurchaseRequired = 40083 + // CodeManagedAccountMinimumOpenID managed account minimum openid + CodeManagedAccountMinimumOpenID = 40084 + // CodeAmountTooSmall amount too small + CodeAmountTooSmall = 40085 + // CodeNodeUsedByStoragePolicy node used by storage policy + CodeNodeUsedByStoragePolicy = 40086 + // CodeDomainNotLicensed domain not licensed + CodeDomainNotLicensed = 40087 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 @@ -218,24 +279,41 @@ const ( CodeNotSet = -1 ) -// DBErr 数据库操作失败 -func DBErr(msg string, err error) Response { +// DBErrDeprecated 数据库操作失败 +func DBErr(c context.Context, msg string, err error) Response { if msg == "" { msg = "Database operation 
failed." } - return Err(CodeDBError, msg, err) + return ErrWithDetails(c, CodeDBError, msg, err) +} + +// DBErrDeprecated 数据库操作失败 +func DBErrDeprecated(msg string, err error) Response { + if msg == "" { + msg = "Database operation failed." + } + return ErrDeprecated(CodeDBError, msg, err) } // ParamErr 各种参数错误 -func ParamErr(msg string, err error) Response { +func ParamErr(c context.Context, msg string, err error) Response { + if msg == "" { + msg = "Invalid parameters." + } + return ErrWithDetails(c, CodeParamErr, msg, err) +} + +// ParamErrDeprecated 各种参数错误 +// Deprecated +func ParamErrDeprecated(msg string, err error) Response { if msg == "" { msg = "Invalid parameters." } - return Err(CodeParamErr, msg, err) + return ErrDeprecated(CodeParamErr, msg, err) } -// Err 通用错误处理 -func Err(errCode int, msg string, err error) Response { +// ErrDeprecated 通用错误处理 +func ErrDeprecated(errCode int, msg string, err error) Response { // 底层错误是AppError,则尝试从AppError中获取详细信息 var appError AppError if errors.As(err, &appError) { @@ -254,3 +332,131 @@ func Err(errCode int, msg string, err error) Response { } return res } + +// ErrWithDetails 通用错误处理 +func ErrWithDetails(c context.Context, errCode int, msg string, err error) Response { + res := Response{ + Code: errCode, + Msg: msg, + CorrelationID: logging.CorrelationID(c).String(), + } + + // 底层错误是AppError,则尝试从AppError中获取详细信息 + var appError AppError + if errors.As(err, &appError) { + res.Code = appError.ErrCode() + err = appError.RawError + res.Msg = appError.Msg + + // Special case for error with detail data + switch res.Code { + case CodeLockConflict: + var lockConflict lock.ConflictError + if errors.As(err, &lockConflict) { + res.Data = lockConflict + } + case CodeBatchOperationNotFullyCompleted: + var errs *AggregateError + if errors.As(err, &errs) { + res.AggregatedError = errs.Expand(c) + } + } + } + + // 生产环境隐藏底层报错 + if err != nil && gin.Mode() != gin.ReleaseMode { + res.Error = err.Error() + } + + return res +} + +// Err Builds error response without addition details, code and message will +// be retrieved from error if possible +func Err(c context.Context, err error) Response { + return ErrWithDetails(c, CodeNotSet, "", err) +} + +// AggregateError is a special error type that contains multiple errors +type AggregateError struct { + errs map[string]error +} + +// NewAggregateError creates a new AggregateError +func NewAggregateError() *AggregateError { + return &AggregateError{ + errs: make(map[string]error, 0), + } +} + +func (e *AggregateError) Error() string { + return fmt.Sprintf("aggregate error: one or more operation failed") +} + +// Add adds an error to the aggregate +func (e *AggregateError) Add(id string, err error) { + e.errs[id] = err +} + +// Merge merges another aggregate error into this one +func (e *AggregateError) Merge(err error) bool { + var errs *AggregateError + if errors.As(err, &errs) { + for id, err := range errs.errs { + e.errs[id] = err + } + + return true + } + return false +} + +// Raw returns the raw error map +func (e *AggregateError) Raw() map[string]error { + return e.errs +} + +func (e *AggregateError) Remove(id string) { + delete(e.errs, id) +} + +// Expand expands the aggregate error into a list of responses +func (e *AggregateError) Expand(ctx context.Context) map[string]Response { + return lo.MapEntries(e.errs, func(id string, err error) (string, Response) { + return id, Err(ctx, err) + }) +} + +// Aggregate aggregates the error and returns nil if there is no error; +// otherwise returns the error itself +func (e 
*AggregateError) Aggregate() error { + if len(e.errs) == 0 { + return nil + } + + msg := "One or more operation failed" + if len(e.errs) == 1 { + for _, err := range e.errs { + msg = err.Error() + } + } + + return NewError(CodeBatchOperationNotFullyCompleted, msg, e) +} + +func (e *AggregateError) FormatFirstN(n int) string { + if len(e.errs) == 0 { + return "" + } + + res := make([]string, 0, n) + for id, err := range e.errs { + res = append(res, fmt.Sprintf("%s: %s", id, err.Error())) + if len(res) >= n { + break + } + } + + return strings.Join(res, ", ") + +} diff --git a/pkg/serializer/error_test.go b/pkg/serializer/error_test.go deleted file mode 100644 index d02fd5dd..00000000 --- a/pkg/serializer/error_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package serializer - -import ( - "errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewError(t *testing.T) { - a := assert.New(t) - err := NewError(400, "Bad Request", errors.New("error")) - a.Error(err) - a.EqualValues(400, err.Code) - - err.WithError(errors.New("error2")) - a.Equal("error2", err.RawError.Error()) - a.Equal("Bad Request", err.Error()) - - resp := &Response{ - Code: 400, - Msg: "Bad Request", - Error: "error", - } - err = NewErrorFromResponse(resp) - a.Error(err) -} - -func TestDBErr(t *testing.T) { - a := assert.New(t) - resp := DBErr("", nil) - a.NotEmpty(resp.Msg) - - resp = ParamErr("", nil) - a.NotEmpty(resp.Msg) -} - -func TestErr(t *testing.T) { - a := assert.New(t) - err := NewError(400, "Bad Request", errors.New("error")) - resp := Err(400, "", err) - a.Equal("Bad Request", resp.Msg) -} diff --git a/pkg/serializer/explorer.go b/pkg/serializer/explorer.go deleted file mode 100644 index da3dc327..00000000 --- a/pkg/serializer/explorer.go +++ /dev/null @@ -1,132 +0,0 @@ -package serializer - -import ( - "encoding/gob" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "time" -) - -func init() { - gob.Register(ObjectProps{}) -} - -// ObjectProps 文件、目录对象的详细属性信息 -type ObjectProps struct { - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Policy string `json:"policy"` - Size uint64 `json:"size"` - ChildFolderNum int `json:"child_folder_num"` - ChildFileNum int `json:"child_file_num"` - Path string `json:"path"` - - QueryDate time.Time `json:"query_date"` -} - -// ObjectList 文件、目录列表 -type ObjectList struct { - Parent string `json:"parent,omitempty"` - Objects []Object `json:"objects"` - Policy *PolicySummary `json:"policy,omitempty"` -} - -// Object 文件或者目录 -type Object struct { - ID string `json:"id"` - Name string `json:"name"` - Path string `json:"path"` - Thumb bool `json:"thumb"` - Size uint64 `json:"size"` - Type string `json:"type"` - Date time.Time `json:"date"` - CreateDate time.Time `json:"create_date"` - Key string `json:"key,omitempty"` - SourceEnabled bool `json:"source_enabled"` -} - -// PolicySummary 用于前端组件使用的存储策略概况 -type PolicySummary struct { - ID string `json:"id"` - Name string `json:"name"` - Type string `json:"type"` - MaxSize uint64 `json:"max_size"` - FileType []string `json:"file_type"` -} - -// BuildObjectList 构建列目录响应 -func BuildObjectList(parent uint, objects []Object, policy *model.Policy) ObjectList { - res := ObjectList{ - Objects: objects, - } - - if parent > 0 { - res.Parent = hashid.HashID(parent, hashid.FolderID) - } - - if policy != nil { - res.Policy = &PolicySummary{ - ID: hashid.HashID(policy.ID, hashid.PolicyID), - Name: policy.Name, - Type: policy.Type, - MaxSize: policy.MaxSize, - 
FileType: policy.OptionsSerialized.FileType, - } - } - - return res -} - -// Sources 获取外链的结果响应 -type Sources struct { - URL string `json:"url"` - Name string `json:"name"` - Parent uint `json:"parent"` - Error string `json:"error,omitempty"` -} - -// DocPreviewSession 文档预览会话响应 -type DocPreviewSession struct { - URL string `json:"url"` - AccessToken string `json:"access_token,omitempty"` - AccessTokenTTL int64 `json:"access_token_ttl,omitempty"` -} - -// WopiFileInfo Response for `CheckFileInfo` -type WopiFileInfo struct { - // Required - BaseFileName string - Version string - Size int64 - - // Breadcrumb - BreadcrumbBrandName string - BreadcrumbBrandUrl string - BreadcrumbFolderName string - BreadcrumbFolderUrl string - - // Post Message - FileSharingPostMessage bool - ClosePostMessage bool - PostMessageOrigin string - - // Other miscellaneous properties - FileNameMaxLength int - LastModifiedTime string - - // User metadata - IsAnonymousUser bool - UserFriendlyName string - UserId string - OwnerId string - - // Permission - ReadOnly bool - UserCanRename bool - UserCanReview bool - UserCanWrite bool - - SupportsRename bool - SupportsReviewing bool - SupportsUpdate bool -} diff --git a/pkg/serializer/explorer_test.go b/pkg/serializer/explorer_test.go deleted file mode 100644 index 00c9efc9..00000000 --- a/pkg/serializer/explorer_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package serializer - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestBuildObjectList(t *testing.T) { - a := assert.New(t) - res := BuildObjectList(1, []Object{{}, {}}, &model.Policy{}) - a.NotEmpty(res.Parent) - a.NotNil(res.Policy) - a.Len(res.Objects, 2) -} diff --git a/pkg/serializer/response.go b/pkg/serializer/response.go index ecfaec25..7f10e1e5 100644 --- a/pkg/serializer/response.go +++ b/pkg/serializer/response.go @@ -2,24 +2,27 @@ package serializer import ( "bytes" + "context" "encoding/base64" "encoding/gob" ) // Response 基础序列化器 type Response struct { - Code int `json:"code"` - Data interface{} `json:"data,omitempty"` - Msg string `json:"msg"` - Error string `json:"error,omitempty"` + Code int `json:"code"` + Data interface{} `json:"data,omitempty"` + AggregatedError interface{} `json:"aggregated_error,omitempty"` + Msg string `json:"msg"` + Error string `json:"error,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` } // NewResponseWithGobData 返回Data字段使用gob编码的Response -func NewResponseWithGobData(data interface{}) Response { +func NewResponseWithGobData(c context.Context, data interface{}) Response { var w bytes.Buffer encoder := gob.NewEncoder(&w) if err := encoder.Encode(data); err != nil { - return Err(CodeInternalSetting, "Failed to encode response content", err) + return ErrWithDetails(c, CodeInternalSetting, "Failed to encode response content", err) } return Response{Data: w.Bytes()} diff --git a/pkg/serializer/response_test.go b/pkg/serializer/response_test.go deleted file mode 100644 index 70c88998..00000000 --- a/pkg/serializer/response_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package serializer - -import ( - "encoding/json" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewResponseWithGobData(t *testing.T) { - a := assert.New(t) - type args struct { - data interface{} - } - - res := NewResponseWithGobData(args{}) - a.Equal(CodeInternalSetting, res.Code) - - res = NewResponseWithGobData("TestNewResponseWithGobData") - a.Equal(0, res.Code) - a.NotEmpty(res.Data) -} - -func 
TestResponse_GobDecode(t *testing.T) { - a := assert.New(t) - res := NewResponseWithGobData("TestResponse_GobDecode") - jsonContent, err := json.Marshal(res) - a.NoError(err) - resDecoded := &Response{} - a.NoError(json.Unmarshal(jsonContent, resDecoded)) - var target string - resDecoded.GobDecode(&target) - a.Equal("TestResponse_GobDecode", target) -} diff --git a/pkg/serializer/setting.go b/pkg/serializer/setting.go index 7e4ce009..f2d6d5f1 100644 --- a/pkg/serializer/setting.go +++ b/pkg/serializer/setting.go @@ -1,92 +1,7 @@ package serializer -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "time" -) - -// SiteConfig 站点全局设置序列 -type SiteConfig struct { - SiteName string `json:"title"` - LoginCaptcha bool `json:"loginCaptcha"` - RegCaptcha bool `json:"regCaptcha"` - ForgetCaptcha bool `json:"forgetCaptcha"` - EmailActive bool `json:"emailActive"` - Themes string `json:"themes"` - DefaultTheme string `json:"defaultTheme"` - HomepageViewMethod string `json:"home_view_method"` - ShareViewMethod string `json:"share_view_method"` - Authn bool `json:"authn"` - User User `json:"user"` - ReCaptchaKey string `json:"captcha_ReCaptchaKey"` - CaptchaType string `json:"captcha_type"` - TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"` - RegisterEnabled bool `json:"registerEnabled"` - AppPromotion bool `json:"app_promotion"` - WopiExts []string `json:"wopi_exts"` -} - -type task struct { - Status int `json:"status"` - Type int `json:"type"` - CreateDate time.Time `json:"create_date"` - Progress int `json:"progress"` - Error string `json:"error"` -} - -// BuildTaskList 构建任务列表响应 -func BuildTaskList(tasks []model.Task, total int) Response { - res := make([]task, 0, len(tasks)) - for _, t := range tasks { - res = append(res, task{ - Status: t.Status, - Type: t.Type, - CreateDate: t.CreatedAt, - Progress: t.Progress, - Error: t.Error, - }) - } - - return Response{Data: map[string]interface{}{ - "total": total, - "tasks": res, - }} -} - -func checkSettingValue(setting map[string]string, key string) string { - if v, ok := setting[key]; ok { - return v - } - return "" -} - -// BuildSiteConfig 站点全局设置 -func BuildSiteConfig(settings map[string]string, user *model.User, wopiExts []string) Response { - var userRes User - if user != nil { - userRes = BuildUser(*user) - } else { - userRes = BuildUser(*model.NewAnonymousUser()) - } - res := Response{ - Data: SiteConfig{ - SiteName: checkSettingValue(settings, "siteName"), - LoginCaptcha: model.IsTrueVal(checkSettingValue(settings, "login_captcha")), - RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")), - ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")), - EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")), - Themes: checkSettingValue(settings, "themes"), - DefaultTheme: checkSettingValue(settings, "defaultTheme"), - HomepageViewMethod: checkSettingValue(settings, "home_view_method"), - ShareViewMethod: checkSettingValue(settings, "share_view_method"), - Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")), - User: userRes, - ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"), - CaptchaType: checkSettingValue(settings, "captcha_type"), - TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"), - RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")), - AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")), - WopiExts: wopiExts, - }} - return res +// VolResponse 
VOL query response +type VolResponse struct { + Signature string `json:"signature"` + Content string `json:"content"` } diff --git a/pkg/serializer/setting_test.go b/pkg/serializer/setting_test.go deleted file mode 100644 index 680edb6f..00000000 --- a/pkg/serializer/setting_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestCheckSettingValue(t *testing.T) { - asserts := assert.New(t) - - asserts.Equal("", checkSettingValue(map[string]string{}, "key")) - asserts.Equal("123", checkSettingValue(map[string]string{"key": "123"}, "key")) -} - -func TestBuildSiteConfig(t *testing.T) { - asserts := assert.New(t) - - res := BuildSiteConfig(map[string]string{"not exist": ""}, &model.User{}, nil) - asserts.Equal("", res.Data.(SiteConfig).SiteName) - - res = BuildSiteConfig(map[string]string{"siteName": "123"}, &model.User{}, nil) - asserts.Equal("123", res.Data.(SiteConfig).SiteName) - - // 非空用户 - res = BuildSiteConfig(map[string]string{"qq_login": "1"}, &model.User{ - Model: gorm.Model{ - ID: 5, - }, - }, nil) - asserts.Len(res.Data.(SiteConfig).User.ID, 4) -} - -func TestBuildTaskList(t *testing.T) { - asserts := assert.New(t) - tasks := []model.Task{{}} - - res := BuildTaskList(tasks, 1) - asserts.NotNil(res) -} diff --git a/pkg/serializer/share.go b/pkg/serializer/share.go deleted file mode 100644 index 94da4c6e..00000000 --- a/pkg/serializer/share.go +++ /dev/null @@ -1,135 +0,0 @@ -package serializer - -import ( - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" -) - -// Share 分享信息序列化 -type Share struct { - Key string `json:"key"` - Locked bool `json:"locked"` - IsDir bool `json:"is_dir"` - CreateDate time.Time `json:"create_date,omitempty"` - Downloads int `json:"downloads"` - Views int `json:"views"` - Expire int64 `json:"expire"` - Preview bool `json:"preview"` - Creator *shareCreator `json:"creator,omitempty"` - Source *shareSource `json:"source,omitempty"` -} - -type shareCreator struct { - Key string `json:"key"` - Nick string `json:"nick"` - GroupName string `json:"group_name"` -} - -type shareSource struct { - Name string `json:"name"` - Size uint64 `json:"size"` -} - -// myShareItem 我的分享列表条目 -type myShareItem struct { - Key string `json:"key"` - IsDir bool `json:"is_dir"` - Password string `json:"password"` - CreateDate time.Time `json:"create_date,omitempty"` - Downloads int `json:"downloads"` - RemainDownloads int `json:"remain_downloads"` - Views int `json:"views"` - Expire int64 `json:"expire"` - Preview bool `json:"preview"` - Source *shareSource `json:"source,omitempty"` -} - -// BuildShareList 构建我的分享列表响应 -func BuildShareList(shares []model.Share, total int) Response { - res := make([]myShareItem, 0, total) - now := time.Now().Unix() - for i := 0; i < len(shares); i++ { - item := myShareItem{ - Key: hashid.HashID(shares[i].ID, hashid.ShareID), - IsDir: shares[i].IsDir, - Password: shares[i].Password, - CreateDate: shares[i].CreatedAt, - Downloads: shares[i].Downloads, - Views: shares[i].Views, - Preview: shares[i].PreviewEnabled, - Expire: -1, - RemainDownloads: shares[i].RemainDownloads, - } - if shares[i].Expires != nil { - item.Expire = shares[i].Expires.Unix() - now - if item.Expire == 0 { - item.Expire = 0 - } - } - if shares[i].File.ID != 0 { - item.Source = &shareSource{ - Name: shares[i].File.Name, - Size: shares[i].File.Size, - } - } else if 
shares[i].Folder.ID != 0 { - item.Source = &shareSource{ - Name: shares[i].Folder.Name, - } - } - - res = append(res, item) - } - - return Response{Data: map[string]interface{}{ - "total": total, - "items": res, - }} -} - -// BuildShareResponse 构建获取分享信息响应 -func BuildShareResponse(share *model.Share, unlocked bool) Share { - creator := share.Creator() - resp := Share{ - Key: hashid.HashID(share.ID, hashid.ShareID), - Locked: !unlocked, - Creator: &shareCreator{ - Key: hashid.HashID(creator.ID, hashid.UserID), - Nick: creator.Nick, - GroupName: creator.Group.Name, - }, - CreateDate: share.CreatedAt, - } - - // 未解锁时只返回基本信息 - if !unlocked { - return resp - } - - resp.IsDir = share.IsDir - resp.Downloads = share.Downloads - resp.Views = share.Views - resp.Preview = share.PreviewEnabled - - if share.Expires != nil { - resp.Expire = share.Expires.Unix() - time.Now().Unix() - } - - if share.IsDir { - source := share.SourceFolder() - resp.Source = &shareSource{ - Name: source.Name, - Size: 0, - } - } else { - source := share.SourceFile() - resp.Source = &shareSource{ - Name: source.Name, - Size: source.Size, - } - } - - return resp - -} diff --git a/pkg/serializer/share_test.go b/pkg/serializer/share_test.go deleted file mode 100644 index 72feb0cc..00000000 --- a/pkg/serializer/share_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package serializer - -import ( - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestBuildShareList(t *testing.T) { - asserts := assert.New(t) - timeNow := time.Now() - - shares := []model.Share{ - { - Expires: &timeNow, - File: model.File{ - Model: gorm.Model{ID: 1}, - }, - }, - { - Folder: model.Folder{ - Model: gorm.Model{ID: 1}, - }, - }, - } - - res := BuildShareList(shares, 2) - asserts.Equal(0, res.Code) -} - -func TestBuildShareResponse(t *testing.T) { - asserts := assert.New(t) - - // 未解锁 - { - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - } - res := BuildShareResponse(share, false) - asserts.EqualValues(0, res.Downloads) - asserts.True(res.Locked) - asserts.NotNil(res.Creator) - } - - // 已解锁,非目录 - { - expires := time.Now().Add(time.Duration(10) * time.Second) - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - Expires: &expires, - File: model.File{ - Model: gorm.Model{ID: 1}, - }, - } - res := BuildShareResponse(share, true) - asserts.EqualValues(1, res.Downloads) - asserts.False(res.Locked) - asserts.NotEmpty(res.Expire) - asserts.NotNil(res.Creator) - } - - // 已解锁,是目录 - { - expires := time.Now().Add(time.Duration(10) * time.Second) - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - Expires: &expires, - Folder: model.Folder{ - Model: gorm.Model{ID: 1}, - }, - IsDir: true, - } - res := BuildShareResponse(share, true) - asserts.EqualValues(1, res.Downloads) - asserts.False(res.Locked) - asserts.NotEmpty(res.Expire) - asserts.NotNil(res.Creator) - } -} diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go deleted file mode 100644 index 04d56d3d..00000000 --- a/pkg/serializer/slave.go +++ /dev/null @@ -1,68 +0,0 @@ -package serializer - -import ( - "crypto/sha1" - "encoding/gob" - "fmt" - - model "github.com/cloudreve/Cloudreve/v3/models" -) - -// RemoteDeleteRequest 远程策略删除接口请求正文 -type RemoteDeleteRequest struct { - Files []string `json:"files"` -} - -// ListRequest 远程策略列文件请求正文 -type ListRequest struct { - Path string `json:"path"` - Recursive bool 
`json:"recursive"` -} - -// NodePingReq 从机节点Ping请求 -type NodePingReq struct { - SiteURL string `json:"site_url"` - SiteID string `json:"site_id"` - IsUpdate bool `json:"is_update"` - CredentialTTL int `json:"credential_ttl"` - Node *model.Node `json:"node"` -} - -// NodePingResp 从机节点Ping响应 -type NodePingResp struct { -} - -// SlaveAria2Call 从机有关Aria2的请求正文 -type SlaveAria2Call struct { - Task *model.Download `json:"task"` - GroupOptions map[string]interface{} `json:"group_options"` - Files []int `json:"files"` -} - -// SlaveTransferReq 从机中转任务创建请求 -type SlaveTransferReq struct { - Src string `json:"src"` - Dst string `json:"dst"` - Policy *model.Policy `json:"policy"` -} - -// Hash 返回创建请求的唯一标识,保持创建请求幂等 -func (s *SlaveTransferReq) Hash(id string) string { - h := sha1.New() - h.Write([]byte(fmt.Sprintf("transfer-%s-%s-%s-%d", id, s.Src, s.Dst, s.Policy.ID))) - bs := h.Sum(nil) - return fmt.Sprintf("%x", bs) -} - -const ( - SlaveTransferSuccess = "success" - SlaveTransferFailed = "failed" -) - -type SlaveTransferResult struct { - Error string -} - -func init() { - gob.Register(SlaveTransferResult{}) -} diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go deleted file mode 100644 index 46b5d2d4..00000000 --- a/pkg/serializer/slave_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" -) - -func TestSlaveTransferReq_Hash(t *testing.T) { - a := assert.New(t) - s1 := &SlaveTransferReq{ - Src: "1", - Policy: &model.Policy{}, - } - s2 := &SlaveTransferReq{ - Src: "2", - Policy: &model.Policy{}, - } - a.NotEqual(s1.Hash("1"), s2.Hash("1")) -} diff --git a/pkg/serializer/upload.go b/pkg/serializer/upload.go index 4e7150b1..32699d7d 100644 --- a/pkg/serializer/upload.go +++ b/pkg/serializer/upload.go @@ -1,64 +1,6 @@ package serializer -import ( - "encoding/gob" - model "github.com/cloudreve/Cloudreve/v3/models" - "time" -) - -// UploadPolicy slave模式下传递的上传策略 -type UploadPolicy struct { - SavePath string `json:"save_path"` - FileName string `json:"file_name"` - AutoRename bool `json:"auto_rename"` - MaxSize uint64 `json:"max_size"` - AllowedExtension []string `json:"allowed_extension"` - CallbackURL string `json:"callback_url"` -} - -// UploadCredential 返回给客户端的上传凭证 -type UploadCredential struct { - SessionID string `json:"sessionID"` - ChunkSize uint64 `json:"chunkSize"` // 分块大小,0 为部分快 - Expires int64 `json:"expires"` // 上传凭证过期时间, Unix 时间戳 - UploadURLs []string `json:"uploadURLs,omitempty"` - Credential string `json:"credential,omitempty"` - UploadID string `json:"uploadID,omitempty"` - Callback string `json:"callback,omitempty"` // 回调地址 - Path string `json:"path,omitempty"` // 存储路径 - AccessKey string `json:"ak,omitempty"` - KeyTime string `json:"keyTime,omitempty"` // COS用有效期 - Policy string `json:"policy,omitempty"` - CompleteURL string `json:"completeURL,omitempty"` -} - -// UploadSession 上传会话 -type UploadSession struct { - Key string // 上传会话 GUID - UID uint // 发起者 - VirtualPath string // 用户文件路径,不含文件名 - Name string // 文件名 - Size uint64 // 文件大小 - SavePath string // 物理存储路径,包含物理文件名 - LastModified *time.Time // 可选的文件最后修改日期 - Policy model.Policy - Callback string // 回调 URL 地址 - CallbackSecret string // 回调 URL - UploadURL string - UploadID string - Credential string -} - -// UploadCallback 上传回调正文 -type UploadCallback struct { - PicInfo string `json:"pic_info"` -} - // GeneralUploadCallbackFailed 存储策略上传回调失败响应 type GeneralUploadCallbackFailed struct { Error string 
`json:"error"` } - -func init() { - gob.Register(UploadSession{}) -} diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go deleted file mode 100644 index 142f424a..00000000 --- a/pkg/serializer/user.go +++ /dev/null @@ -1,156 +0,0 @@ -package serializer - -import ( - "fmt" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/duo-labs/webauthn/webauthn" - "time" -) - -// CheckLogin 检查登录 -func CheckLogin() Response { - return Response{ - Code: CodeCheckLogin, - Msg: "Login required", - } -} - -// User 用户序列化器 -type User struct { - ID string `json:"id"` - Email string `json:"user_name"` - Nickname string `json:"nickname"` - Status int `json:"status"` - Avatar string `json:"avatar"` - CreatedAt time.Time `json:"created_at"` - PreferredTheme string `json:"preferred_theme"` - Anonymous bool `json:"anonymous"` - Group group `json:"group"` - Tags []tag `json:"tags"` -} - -type group struct { - ID uint `json:"id"` - Name string `json:"name"` - AllowShare bool `json:"allowShare"` - AllowRemoteDownload bool `json:"allowRemoteDownload"` - AllowArchiveDownload bool `json:"allowArchiveDownload"` - ShareDownload bool `json:"shareDownload"` - CompressEnabled bool `json:"compress"` - WebDAVEnabled bool `json:"webdav"` - SourceBatchSize int `json:"sourceBatch"` - AdvanceDelete bool `json:"advanceDelete"` - AllowWebDAVProxy bool `json:"allowWebDAVProxy"` -} - -type tag struct { - ID string `json:"id"` - Name string `json:"name"` - Icon string `json:"icon"` - Color string `json:"color"` - Type int `json:"type"` - Expression string `json:"expression"` -} - -type storage struct { - Used uint64 `json:"used"` - Free uint64 `json:"free"` - Total uint64 `json:"total"` -} - -// WebAuthnCredentials 外部验证器凭证 -type WebAuthnCredentials struct { - ID []byte `json:"id"` - FingerPrint string `json:"fingerprint"` -} - -// BuildWebAuthnList 构建设置页面凭证列表 -func BuildWebAuthnList(credentials []webauthn.Credential) []WebAuthnCredentials { - res := make([]WebAuthnCredentials, 0, len(credentials)) - for _, v := range credentials { - credential := WebAuthnCredentials{ - ID: v.ID, - FingerPrint: fmt.Sprintf("% X", v.Authenticator.AAGUID), - } - res = append(res, credential) - } - - return res -} - -// BuildUser 序列化用户 -func BuildUser(user model.User) User { - tags, _ := model.GetTagsByUID(user.ID) - return User{ - ID: hashid.HashID(user.ID, hashid.UserID), - Email: user.Email, - Nickname: user.Nick, - Status: user.Status, - Avatar: user.Avatar, - CreatedAt: user.CreatedAt, - PreferredTheme: user.OptionsSerialized.PreferredTheme, - Anonymous: user.IsAnonymous(), - Group: group{ - ID: user.GroupID, - Name: user.Group.Name, - AllowShare: user.Group.ShareEnabled, - AllowRemoteDownload: user.Group.OptionsSerialized.Aria2, - AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownload, - ShareDownload: user.Group.OptionsSerialized.ShareDownload, - CompressEnabled: user.Group.OptionsSerialized.ArchiveTask, - WebDAVEnabled: user.Group.WebDAVEnabled, - AllowWebDAVProxy: user.Group.OptionsSerialized.WebDAVProxy, - SourceBatchSize: user.Group.OptionsSerialized.SourceBatchSize, - AdvanceDelete: user.Group.OptionsSerialized.AdvanceDelete, - }, - Tags: buildTagRes(tags), - } -} - -// BuildUserResponse 序列化用户响应 -func BuildUserResponse(user model.User) Response { - return Response{ - Data: BuildUser(user), - } -} - -// BuildUserStorageResponse 序列化用户存储概况响应 -func BuildUserStorageResponse(user model.User) Response { - total := user.Group.MaxStorage - storageResp := storage{ - Used: 
user.Storage, - Free: total - user.Storage, - Total: total, - } - - if total < user.Storage { - storageResp.Free = 0 - } - - return Response{ - Data: storageResp, - } -} - -// buildTagRes 构建标签列表 -func buildTagRes(tags []model.Tag) []tag { - res := make([]tag, 0, len(tags)) - for i := 0; i < len(tags); i++ { - newTag := tag{ - ID: hashid.HashID(tags[i].ID, hashid.TagID), - Name: tags[i].Name, - Icon: tags[i].Icon, - Color: tags[i].Color, - Type: tags[i].Type, - } - if newTag.Type != 0 { - newTag.Expression = tags[i].Expression - - } - res = append(res, newTag) - } - - return res -} diff --git a/pkg/serializer/user_test.go b/pkg/serializer/user_test.go deleted file mode 100644 index 29421861..00000000 --- a/pkg/serializer/user_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package serializer - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/duo-labs/webauthn/webauthn" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestBuildUser(t *testing.T) { - asserts := assert.New(t) - user := model.User{} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - res := BuildUser(user) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(res) - -} - -func TestBuildUserResponse(t *testing.T) { - asserts := assert.New(t) - user := model.User{} - res := BuildUserResponse(user) - asserts.NotNil(res) -} - -func TestBuildUserStorageResponse(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_0", uint64(0), 0) - - { - user := model.User{ - Storage: 0, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(0), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(10), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 6, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(6), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(4), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 20, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(20), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(0), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 6, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(6), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(4), res.Data.(storage).Free) - } -} - -func TestBuildTagRes(t *testing.T) { - asserts := assert.New(t) - tags := []model.Tag{ - { - Type: 0, - Expression: "exp", - }, - { - Type: 1, - Expression: "exp", - }, - } - res := buildTagRes(tags) - asserts.Len(res, 2) - asserts.Equal("", res[0].Expression) - asserts.Equal("exp", res[1].Expression) -} - -func TestBuildWebAuthnList(t *testing.T) { - asserts := assert.New(t) - credentials := []webauthn.Credential{{}} - res := BuildWebAuthnList(credentials) 
- asserts.Len(res, 1) -} diff --git a/pkg/sessionstore/kv.go b/pkg/sessionstore/kv.go index 193d5c68..7577abb3 100644 --- a/pkg/sessionstore/kv.go +++ b/pkg/sessionstore/kv.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/base32" "encoding/gob" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "net/http" @@ -76,7 +76,7 @@ func (s *kvStore) New(r *http.Request, name string) (*sessions.Session, error) { func (s *kvStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { // Marked for deletion. if session.Options.MaxAge <= 0 { - if err := s.store.Delete([]string{session.ID}, s.prefix); err != nil { + if err := s.store.Delete(s.prefix, session.ID); err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) diff --git a/pkg/sessionstore/sessionstore.go b/pkg/sessionstore/sessionstore.go index 3b1c302e..052a4956 100644 --- a/pkg/sessionstore/sessionstore.go +++ b/pkg/sessionstore/sessionstore.go @@ -1,7 +1,7 @@ package sessionstore import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" "github.com/gin-contrib/sessions" ) diff --git a/pkg/setting/adapters.go b/pkg/setting/adapters.go new file mode 100644 index 00000000..47d80fa6 --- /dev/null +++ b/pkg/setting/adapters.go @@ -0,0 +1,133 @@ +package setting + +import ( + "context" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/cache" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/samber/lo" + "os" + "strings" +) + +const ( + KvSettingPrefix = "setting_" + EnvSettingOverwritePrefix = "CR_SETTING_" +) + +// SettingStoreAdapter chains a setting get operation, if current adapter cannot locate setting value, +// it will invoke next adapter until last one. +type SettingStoreAdapter interface { + // Get a string setting from underlying store, if setting not found, default + // value will be used. + Get(ctx context.Context, name string, defaultVal any) any +} + +// NewDbSettingStore creates an adapter using DB setting store. Only string type value is supported. +func NewDbSettingStore(c inventory.SettingClient, next SettingStoreAdapter) SettingStoreAdapter { + return &dbSettingStore{ + c: c, + next: next, + } +} + +// NewKvSettingStore creates an adapter using KV setting store. +func NewKvSettingStore(c cache.Driver, next SettingStoreAdapter) SettingStoreAdapter { + return &kvSettingStore{ + kv: c, + next: next, + } +} + +// NewConfSettingStore creates an adapter using static config overwrite. Only string type value is supported. +func NewConfSettingStore(c conf.ConfigProvider, next SettingStoreAdapter) SettingStoreAdapter { + return &staticSettingStore{ + settings: c.OptionOverwrite(), + next: next, + } +} + +// NewDbDefaultStore creates an adapter that always returns default setting value defined in inventory.DefaultSettings. +// Only string type value is supported. +func NewDbDefaultStore(next SettingStoreAdapter) SettingStoreAdapter { + return &staticSettingStore{ + settings: lo.MapValues(inventory.DefaultSettings, func(v string, _ string) any { + return v + }), + next: next, + } +} + +// NewEnvOverwriteStore creates an adapter that always returns overrided setting value defined in environment variables. 
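To illustrate the adapter chain introduced in pkg/setting/adapters.go, here is a hedged wiring sketch. The constructor names and the `Get` signature come from the diff; the chain order, the surrounding functions, and their parameters are assumptions made only for illustration.

```go
// Hypothetical wiring-only package; names outside the imported Cloudreve
// packages are introduced here purely for illustration.
package wiring

import (
	"context"

	"github.com/cloudreve/Cloudreve/v4/inventory"
	"github.com/cloudreve/Cloudreve/v4/pkg/cache"
	"github.com/cloudreve/Cloudreve/v4/pkg/conf"
	"github.com/cloudreve/Cloudreve/v4/pkg/logging"
	"github.com/cloudreve/Cloudreve/v4/pkg/setting"
)

// buildSettingChain composes the adapters so that environment overrides are
// consulted first, then static config, then the KV cache, then the DB, and
// finally the built-in defaults. This order is an assumed example, not
// necessarily the one used by the application.
func buildSettingChain(cfg conf.ConfigProvider, kv cache.Driver, db inventory.SettingClient, l logging.Logger) setting.SettingStoreAdapter {
	defaults := setting.NewDbDefaultStore(nil)
	dbStore := setting.NewDbSettingStore(db, defaults)
	kvStore := setting.NewKvSettingStore(kv, dbStore)
	confStore := setting.NewConfSettingStore(cfg, kvStore)
	return setting.NewEnvOverrideStore(confStore, l)
}

// siteName reads a single string setting through the chain, falling back to
// the provided default when no adapter can resolve it.
func siteName(ctx context.Context, chain setting.SettingStoreAdapter) string {
	v, _ := chain.Get(ctx, "siteName", "Cloudreve").(string)
	return v
}
```

Note that the KV adapter in the diff caches whatever value the next adapter returns, so placing it in front of the DB store avoids repeated database reads for frequently accessed settings.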
+func NewEnvOverrideStore(next SettingStoreAdapter, l logging.Logger) SettingStoreAdapter { + allEnv := os.Environ() + defaults := make(map[string]any) + for _, env := range allEnv { + kv := strings.SplitN(env, "=", 2) + if strings.HasPrefix(kv[0], EnvSettingOverwritePrefix) { + key := strings.TrimPrefix(kv[0], EnvSettingOverwritePrefix) + defaults[key] = kv[1] + l.Info("Override setting %q with value %q from environment", key, kv[1]) + } + } + + return &staticSettingStore{ + settings: defaults, + next: next, + } +} + +type dbSettingStore struct { + c inventory.SettingClient + next SettingStoreAdapter +} + +func (s *dbSettingStore) Get(ctx context.Context, name string, defaultVal any) any { + if val, err := s.c.Get(ctx, name); err == nil { + return val + } + + if s.next != nil { + return s.next.Get(ctx, name, defaultVal) + } + + return defaultVal +} + +type kvSettingStore struct { + kv cache.Driver + next SettingStoreAdapter +} + +func (s *kvSettingStore) Get(ctx context.Context, name string, defaultVal any) any { + if val, exist := s.kv.Get(KvSettingPrefix + name); exist { + return val + } + + if s.next != nil { + nextVal := s.next.Get(ctx, name, defaultVal) + // cache setting value + s.kv.Set(KvSettingPrefix+name, nextVal, 0) + return nextVal + } + + return defaultVal +} + +type staticSettingStore struct { + settings map[string]any + next SettingStoreAdapter +} + +func (s *staticSettingStore) Get(ctx context.Context, name string, defaultVal any) any { + if val, ok := s.settings[name]; ok { + return val + } + + if s.next != nil { + return s.next.Get(ctx, name, defaultVal) + } + + return defaultVal +} diff --git a/pkg/setting/provider.go b/pkg/setting/provider.go new file mode 100644 index 00000000..4cc4aa07 --- /dev/null +++ b/pkg/setting/provider.go @@ -0,0 +1,784 @@ +package setting + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/pkg/auth/requestinfo" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" +) + +type ( + // Provider provides strong type setting access. + Provider interface { + // Site basic information + SiteBasic(ctx context.Context) *SiteBasic + // PWA related settings + PWA(ctx context.Context) *PWASetting + // RegisterEnabled returns true if public sign-up is enabled. + RegisterEnabled(ctx context.Context) bool + // AuthnEnabled returns true if Webauthn is enabled. + AuthnEnabled(ctx context.Context) bool + // RegCaptchaEnabled returns true if registration captcha is enabled. + RegCaptchaEnabled(ctx context.Context) bool + // LoginCaptchaEnabled returns true if login captcha is enabled. + LoginCaptchaEnabled(ctx context.Context) bool + // ForgotPasswordCaptchaEnabled returns true if forgot password captcha is enabled. + ForgotPasswordCaptchaEnabled(ctx context.Context) bool + // CaptchaType returns the type of captcha used. + CaptchaType(ctx context.Context) CaptchaType + // ReCaptcha returns the Google reCaptcha settings. + ReCaptcha(ctx context.Context) *ReCaptcha + // TcCaptcha returns the Tencent Cloud Captcha settings. + TcCaptcha(ctx context.Context) *TcCaptcha + // TurnstileCaptcha returns the Cloudflare Turnstile settings. + TurnstileCaptcha(ctx context.Context) *Turnstile + // EmailActivationEnabled returns true if email activation is required. + EmailActivationEnabled(ctx context.Context) bool + // DefaultGroup returns the default group ID for new users. + DefaultGroup(ctx context.Context) int + // SMTP returns the SMTP settings. 
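+	// This covers the sender identity, server host/port, credentials, encryption and keep-alive
+	// options used for outgoing mail.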
+ SMTP(ctx context.Context) *SMTP + // SiteURL returns the basic URL. + SiteURL(ctx context.Context) *url.URL + // SecretKey returns the secret key for general signature. + SecretKey(ctx context.Context) string + // ActivationEmailTemplate returns the email template for activation. + ActivationEmailTemplate(ctx context.Context) []EmailTemplate + // ResetEmailTemplate returns the email template for reset password. + ResetEmailTemplate(ctx context.Context) []EmailTemplate + // TokenAuth returns token based auth related settings. + TokenAuth(ctx context.Context) *TokenAuth + // HashIDSalt returns the salt used for hash ID generation. + HashIDSalt(ctx context.Context) string + // DBFS returns the DBFS related settings. + DBFS(ctx context.Context) *DBFS + // MaxBatchedFile returns the maximum number of files in a batch operation. + MaxBatchedFile(ctx context.Context) int + // UploadSessionTTL returns the TTL of upload session. + UploadSessionTTL(ctx context.Context) time.Duration + // MaxOnlineEditSize returns the maximum size of online editing. + MaxOnlineEditSize(ctx context.Context) int64 + // SlaveRequestSignTTL returns the TTL of slave request signature. + SlaveRequestSignTTL(ctx context.Context) int + // ChunkRetryLimit returns the maximum number of chunk retries. + ChunkRetryLimit(ctx context.Context) int + // UseChunkBuffer returns true if chunk buffer is enabled. + UseChunkBuffer(ctx context.Context) bool + // Queue returns the queue settings. + Queue(ctx context.Context, queueType QueueType) *QueueSetting + // EntityUrlCacheMargin returns the safe margin of entity URL cache. URL cache will + // expire in (EntityUrlValidDuration - EntityUrlCacheMargin). + EntityUrlCacheMargin(ctx context.Context) int + // EntityUrlValidDuration returns the valid duration of entity URL. + EntityUrlValidDuration(ctx context.Context) time.Duration + // PublicResourceMaxAge returns the max age of public resources. + PublicResourceMaxAge(ctx context.Context) int + // MediaMetaEnabled returns true if media meta is enabled. + MediaMetaEnabled(ctx context.Context) bool + // MediaMetaExifEnabled returns true if media meta exif is enabled. + MediaMetaExifEnabled(ctx context.Context) bool + // MediaMetaExifSizeLimit returns the size limit of media meta exif. first return value is for local sources; + // second return value is for remote sources. + MediaMetaExifSizeLimit(ctx context.Context) (int64, int64) + // MediaMetaExifBruteForce returns true if media meta exif brute force search is enabled. + MediaMetaExifBruteForce(ctx context.Context) bool + // MediaMetaMusicEnabled returns true if media meta audio is enabled. + MediaMetaMusicEnabled(ctx context.Context) bool + // MediaMetaMusicSizeLimit returns the size limit of media meta audio. first return value is for local sources; + MediaMetaMusicSizeLimit(ctx context.Context) (int64, int64) + // MediaMetaFFProbeEnabled returns true if media meta ffprobe is enabled. + MediaMetaFFProbeEnabled(ctx context.Context) bool + // MediaMetaFFProbeSizeLimit returns the size limit of media meta ffprobe. first return value is for local sources; + MediaMetaFFProbeSizeLimit(ctx context.Context) (int64, int64) + // MediaMetaFFProbePath returns the path of ffprobe executable. + MediaMetaFFProbePath(ctx context.Context) string + // ThumbSize returns the size limit of thumbnails. + ThumbSize(ctx context.Context) (int, int) + // ThumbEncode returns the thumbnail encoding settings. 
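+	// These cover the output format (e.g. "jpg") and encode quality applied when generating thumbnails.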
+ ThumbEncode(ctx context.Context) *ThumbEncode + // BuiltinThumbGeneratorEnabled returns true if builtin thumb generator is enabled. + BuiltinThumbGeneratorEnabled(ctx context.Context) bool + // BuiltinThumbMaxSize returns the maximum size of builtin thumb generator. + BuiltinThumbMaxSize(ctx context.Context) int64 + // TempPath returns the path of temporary directory. + TempPath(ctx context.Context) string + // ThumbEntitySuffix returns the suffix of entity thumbnails. + ThumbEntitySuffix(ctx context.Context) string + // ThumbSlaveSidecarSuffix returns the suffix of slave sidecar thumbnails. + ThumbSlaveSidecarSuffix(ctx context.Context) string + // ThumbGCAfterGen returns true if force GC is invoked after thumb generation. + ThumbGCAfterGen(ctx context.Context) bool + // FFMpegPath returns the path of ffmpeg executable. + FFMpegPath(ctx context.Context) string + // FFMpegThumbGeneratorEnabled returns true if ffmpeg thumb generator is enabled. + FFMpegThumbGeneratorEnabled(ctx context.Context) bool + // FFMpegThumbExts returns the supported extensions of ffmpeg thumb generator. + FFMpegThumbExts(ctx context.Context) []string + // FFMpegThumbSeek returns the seek time of ffmpeg thumb generator. + FFMpegThumbSeek(ctx context.Context) string + // FFMpegThumbMaxSize returns the maximum size of ffmpeg thumb generator. + FFMpegThumbMaxSize(ctx context.Context) int64 + // VipsThumbGeneratorEnabled returns true if vips thumb generator is enabled. + VipsThumbGeneratorEnabled(ctx context.Context) bool + // VipsThumbExts returns the supported extensions of vips thumb generator. + VipsThumbExts(ctx context.Context) []string + // VipsThumbMaxSize returns the maximum size of vips thumb generator. + VipsThumbMaxSize(ctx context.Context) int64 + // VipsPath returns the path of vips executable. + VipsPath(ctx context.Context) string + // LibreOfficeThumbGeneratorEnabled returns true if libreoffice thumb generator is enabled. + LibreOfficeThumbGeneratorEnabled(ctx context.Context) bool + // LibreOfficeThumbExts returns the supported extensions of libreoffice thumb generator. + LibreOfficeThumbExts(ctx context.Context) []string + // LibreOfficeThumbMaxSize returns the maximum size of libreoffice thumb generator. + LibreOfficeThumbMaxSize(ctx context.Context) int64 + // LibreOfficePath returns the path of libreoffice executable. + LibreOfficePath(ctx context.Context) string + // MusicCoverThumbGeneratorEnabled returns true if music cover thumb generator is enabled. + MusicCoverThumbGeneratorEnabled(ctx context.Context) bool + // MusicCoverThumbMaxSize returns the maximum size of music cover thumb generator. + MusicCoverThumbMaxSize(ctx context.Context) int64 + // MusicCoverThumbExts returns the supported extensions of music cover thumb generator. + MusicCoverThumbExts(ctx context.Context) []string + // Cron returns the crontab settings. + Cron(ctx context.Context, t CronType) string + // Theme returns the theme settings. + Theme(ctx context.Context) *Theme + // Logo returns the logo settings. + Logo(ctx context.Context) *Logo + // LegalDocuments returns the legal documents settings. + LegalDocuments(ctx context.Context) *LegalDocuments + // Captcha returns the captcha settings. + Captcha(ctx context.Context) *Captcha + // ExplorerFrontendSettings returns the explorer frontend settings. + ExplorerFrontendSettings(ctx context.Context) *ExplorerFrontendSettings + // SearchCategoryQuery returns the search category query. 
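+	// The underlying setting key is assembled as "explorer_category_<category>_query".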
+ SearchCategoryQuery(ctx context.Context, category SearchCategory) string + // EmojiPresets returns the emoji presets used in file icon customization. + EmojiPresets(ctx context.Context) string + // MapSetting returns the EXIF GPS map related settings. + MapSetting(ctx context.Context) *MapSetting + // FolderPropsCacheTTL returns the cache TTL of folder summary. + FolderPropsCacheTTL(ctx context.Context) int + // FileViewers returns the file viewers settings. + FileViewers(ctx context.Context) []ViewerGroup + // ViewerSessionTTL returns the TTL of viewer session. + ViewerSessionTTL(ctx context.Context) int + // MimeMapping returns the extension to MIME mapping settings. + MimeMapping(ctx context.Context) string + // MaxParallelTransfer returns the maximum parallel transfer in workflows. + MaxParallelTransfer(ctx context.Context) int + // ArchiveDownloadSessionTTL returns the TTL of archive download session. + ArchiveDownloadSessionTTL(ctx context.Context) int + // AppSetting returns the app related settings. + AppSetting(ctx context.Context) *AppSetting + // Avatar returns the avatar settings. + Avatar(ctx context.Context) *Avatar + // AvatarProcess returns the avatar process settings. + AvatarProcess(ctx context.Context) *AvatarProcess + // UseFirstSiteUrl returns the first site URL. + AllSiteURLs(ctx context.Context) []*url.URL + } + UseFirstSiteUrlCtxKey = struct{} +) + +// NewProvider creates a new setting provider. +func NewProvider(root SettingStoreAdapter) Provider { + return &settingProvider{ + adapterChain: root, + } +} + +const ( + stringListDefault = "DEFAULT" + stringListDefaultSeparator = "," +) + +var defaultBoolSet = &boolset.BooleanSet{} + +type ( + SiteHostAllowListGetter interface { + AllowedHost() []string + } + settingProvider struct { + adapterChain SettingStoreAdapter + } +) + +func (s *settingProvider) License(ctx context.Context) string { + return s.getString(ctx, "license", "") +} + +func (s *settingProvider) AvatarProcess(ctx context.Context) *AvatarProcess { + return &AvatarProcess{ + Path: s.getString(ctx, "avatar_path", "avatar"), + MaxFileSize: s.getInt64(ctx, "avatar_size", 4194304), + MaxWidth: s.getInt(ctx, "avatar_size_l", 200), + } +} + +func (s *settingProvider) Avatar(ctx context.Context) *Avatar { + return &Avatar{ + Gravatar: s.getString(ctx, "gravatar_server", ""), + Path: s.getString(ctx, "avatar_path", "avatar"), + } +} + +func (s *settingProvider) FileViewers(ctx context.Context) []ViewerGroup { + raw := s.getString(ctx, "file_viewers", "[]") + var viewers []ViewerGroup + if err := json.Unmarshal([]byte(raw), &viewers); err != nil { + return []ViewerGroup{} + } + + return viewers +} + +func (s *settingProvider) AppSetting(ctx context.Context) *AppSetting { + return &AppSetting{ + Promotion: s.getBoolean(ctx, "show_app_promotion", false), + } +} + +func (s *settingProvider) MaxParallelTransfer(ctx context.Context) int { + return s.getInt(ctx, "max_parallel_transfer", 4) +} + +func (s *settingProvider) ArchiveDownloadSessionTTL(ctx context.Context) int { + return s.getInt(ctx, "archive_timeout", 20) +} + +func (s *settingProvider) ViewerSessionTTL(ctx context.Context) int { + return s.getInt(ctx, "viewer_session_timeout", 36000) +} + +func (s *settingProvider) MapSetting(ctx context.Context) *MapSetting { + return &MapSetting{ + Provider: MapProvider(s.getString(ctx, "map_provider", "openstreetmap")), + GoogleTileType: MapGoogleTileType(s.getString(ctx, "map_google_tile_type", "roadmap")), + } +} + +func (s *settingProvider) MimeMapping(ctx 
context.Context) string { + return s.getString(ctx, "mime_mapping", "{}") +} + +func (s *settingProvider) Logo(ctx context.Context) *Logo { + return &Logo{ + Normal: s.getString(ctx, "site_logo", "/static/img/logo.svg"), + Light: s.getString(ctx, "site_logo_light", "/static/img/logo_light.svg"), + } +} + +func (s *settingProvider) ExplorerFrontendSettings(ctx context.Context) *ExplorerFrontendSettings { + return &ExplorerFrontendSettings{ + Icons: s.getString(ctx, "explorer_icons", "[]"), + } +} + +func (s *settingProvider) SearchCategoryQuery(ctx context.Context, category SearchCategory) string { + return s.getString(ctx, fmt.Sprintf("explorer_category_%s_query", category), "") +} + +func (s *settingProvider) Captcha(ctx context.Context) *Captcha { + return &Captcha{ + Height: s.getInt(ctx, "captcha_height", 60), + Width: s.getInt(ctx, "captcha_width", 240), + Mode: CaptchaMode(s.getInt(ctx, "captcha_mode", int(CaptchaModeNumberAlphabet))), + ComplexOfNoiseText: s.getInt(ctx, "captcha_ComplexOfNoiseText", 0), + ComplexOfNoiseDot: s.getInt(ctx, "captcha_ComplexOfNoiseDot", 0), + IsShowHollowLine: s.getBoolean(ctx, "captcha_IsShowHollowLine", false), + IsShowNoiseDot: s.getBoolean(ctx, "captcha_IsShowNoiseDot", false), + IsShowNoiseText: s.getBoolean(ctx, "captcha_IsShowNoiseText", false), + IsShowSlimeLine: s.getBoolean(ctx, "captcha_IsShowSlimeLine", false), + IsShowSineLine: s.getBoolean(ctx, "captcha_IsShowSineLine", false), + Length: s.getInt(ctx, "captcha_CaptchaLen", 6), + } +} + +func (s *settingProvider) LegalDocuments(ctx context.Context) *LegalDocuments { + return &LegalDocuments{ + PrivacyPolicy: s.getString(ctx, "privacy_policy_url", ""), + TermsOfService: s.getString(ctx, "tos_url", ""), + } +} + +func (s *settingProvider) FolderPropsCacheTTL(ctx context.Context) int { + return s.getInt(ctx, "folder_props_timeout", 300) +} + +func (s *settingProvider) EmojiPresets(ctx context.Context) string { + return s.getString(ctx, "emojis", "{}") +} + +func (s *settingProvider) Theme(ctx context.Context) *Theme { + return &Theme{ + Themes: s.getString(ctx, "theme_options", "{}"), + DefaultTheme: s.getString(ctx, "defaultTheme", ""), + } +} + +func (s *settingProvider) Cron(ctx context.Context, t CronType) string { + return s.getString(ctx, "cron_"+string(t), "@hourly") +} + +func (s *settingProvider) BuiltinThumbGeneratorEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_builtin_enabled", true) +} + +func (s *settingProvider) BuiltinThumbMaxSize(ctx context.Context) int64 { + return s.getInt64(ctx, "thumb_builtin_max_size", 78643200) +} + +func (s *settingProvider) MusicCoverThumbGeneratorEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_music_cover_enabled", true) +} + +func (s *settingProvider) MusicCoverThumbMaxSize(ctx context.Context) int64 { + return s.getInt64(ctx, "thumb_music_cover_max_size", 1073741824) +} + +func (s *settingProvider) MusicCoverThumbExts(ctx context.Context) []string { + return s.getStringList(ctx, "thumb_music_cover_exts", []string{}) +} + +func (s *settingProvider) FFMpegPath(ctx context.Context) string { + return s.getString(ctx, "thumb_ffmpeg_path", "ffmpeg") +} + +func (s *settingProvider) FFMpegThumbGeneratorEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_ffmpeg_enabled", false) +} + +func (s *settingProvider) FFMpegThumbExts(ctx context.Context) []string { + return s.getStringList(ctx, "thumb_ffmpeg_exts", []string{}) +} + +func (s *settingProvider) FFMpegThumbSeek(ctx context.Context) string { 
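+	// The seek position is an ffmpeg-style timestamp (HH:MM:SS.ms); the default grabs the frame one second in.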
+ return s.getString(ctx, "thumb_ffmpeg_seek", "00:00:01.00") +} + +func (s *settingProvider) FFMpegThumbMaxSize(ctx context.Context) int64 { + return s.getInt64(ctx, "thumb_ffmpeg_max_size", 10737418240) +} + +func (s *settingProvider) VipsThumbGeneratorEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_vips_enabled", false) +} + +func (s *settingProvider) VipsThumbMaxSize(ctx context.Context) int64 { + return s.getInt64(ctx, "thumb_vips_max_size", 78643200) +} + +func (s *settingProvider) VipsThumbExts(ctx context.Context) []string { + return s.getStringList(ctx, "thumb_vips_exts", []string{}) +} + +func (s *settingProvider) VipsPath(ctx context.Context) string { + return s.getString(ctx, "thumb_vips_path", "vips") +} + +func (s *settingProvider) LibreOfficeThumbGeneratorEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_libreoffice_enabled", false) +} + +func (s *settingProvider) LibreOfficeThumbMaxSize(ctx context.Context) int64 { + return s.getInt64(ctx, "thumb_libreoffice_max_size", 78643200) +} + +func (s *settingProvider) LibreOfficePath(ctx context.Context) string { + return s.getString(ctx, "thumb_libreoffice_path", "soffice") +} + +func (s *settingProvider) LibreOfficeThumbExts(ctx context.Context) []string { + return s.getStringList(ctx, "thumb_libreoffice_exts", []string{}) +} + +func (s *settingProvider) ThumbSize(ctx context.Context) (int, int) { + return s.getInt(ctx, "thumb_width", 400), s.getInt(ctx, "thumb_height", 300) +} + +func (s *settingProvider) ThumbEncode(ctx context.Context) *ThumbEncode { + return &ThumbEncode{ + Format: s.getString(ctx, "thumb_encode_method", "jpg"), + Quality: s.getInt(ctx, "thumb_encode_quality", 85), + } +} + +func (s *settingProvider) ThumbEntitySuffix(ctx context.Context) string { + return s.getString(ctx, "thumb_entity_suffix", "._thumb") +} + +func (s *settingProvider) ThumbSlaveSidecarSuffix(ctx context.Context) string { + return s.getString(ctx, "thumb_slave_sidecar_suffix", "._thumb_sidecar") +} + +func (s *settingProvider) ThumbGCAfterGen(ctx context.Context) bool { + return s.getBoolean(ctx, "thumb_gc_after_gen", false) +} + +func (s *settingProvider) TempPath(ctx context.Context) string { + return s.getString(ctx, "temp_path", "temp") +} + +func (s *settingProvider) MediaMetaFFProbePath(ctx context.Context) string { + return s.getString(ctx, "media_meta_ffprobe_path", "ffprobe") +} + +func (s *settingProvider) MediaMetaFFProbeSizeLimit(ctx context.Context) (int64, int64) { + return s.getInt64(ctx, "media_meta_ffprobe_size_local", 0), s.getInt64(ctx, "media_meta_ffprobe_size_remote", 0) +} + +func (s *settingProvider) MediaMetaFFProbeEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "media_meta_ffprobe", true) +} + +func (s *settingProvider) MediaMetaMusicSizeLimit(ctx context.Context) (int64, int64) { + return s.getInt64(ctx, "media_meta_music_size_local", 0), s.getInt64(ctx, "media_meta_music_size_remote", 0) +} + +func (s *settingProvider) MediaMetaMusicEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "media_meta_music", true) +} + +func (s *settingProvider) MediaMetaExifBruteForce(ctx context.Context) bool { + return s.getBoolean(ctx, "media_meta_exif_brute_force", false) +} + +func (s *settingProvider) MediaMetaExifSizeLimit(ctx context.Context) (int64, int64) { + return s.getInt64(ctx, "media_meta_exif_size_local", 0), s.getInt64(ctx, "media_meta_exif_size_remote", 0) +} + +func (s *settingProvider) MediaMetaExifEnabled(ctx context.Context) bool { + return 
s.getBoolean(ctx, "media_meta_exif", true) +} + +func (s *settingProvider) MediaMetaEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "media_meta", true) +} + +func (s *settingProvider) PublicResourceMaxAge(ctx context.Context) int { + return s.getInt(ctx, "public_resource_maxage", 0) +} + +func (s *settingProvider) EntityUrlCacheMargin(ctx context.Context) int { + return s.getInt(ctx, "entity_url_cache_margin", 600) +} + +func (s *settingProvider) EntityUrlValidDuration(ctx context.Context) time.Duration { + return time.Duration(s.getInt(ctx, "entity_url_default_ttl", 3600)) * time.Second +} + +func (s *settingProvider) Queue(ctx context.Context, queueType QueueType) *QueueSetting { + queueTypeStr := string(queueType) + return &QueueSetting{ + WorkerNum: s.getInt(ctx, "queue_"+queueTypeStr+"_worker_num,", 15), + MaxExecution: time.Duration(s.getInt(ctx, "queue_"+queueTypeStr+"_max_execution", 86400)) * time.Second, + BackoffFactor: s.getFloat64(ctx, "queue_"+queueTypeStr+"_backoff_factor", 4), + BackoffMaxDuration: time.Duration(s.getInt(ctx, "queue_"+queueTypeStr+"_backoff_max_duration", 3600)) * time.Second, + MaxRetry: s.getInt(ctx, "queue_"+queueTypeStr+"_max_retry", 5), + RetryDelay: time.Duration(s.getInt(ctx, "queue_"+queueTypeStr+"_retry_delay", 5)) * time.Second, + } +} + +func (s *settingProvider) UseChunkBuffer(ctx context.Context) bool { + return s.getBoolean(ctx, "use_temp_chunk_buffer", true) +} + +func (s *settingProvider) ChunkRetryLimit(ctx context.Context) int { + return s.getInt(ctx, "chunk_retries", 3) +} + +func (s *settingProvider) SlaveRequestSignTTL(ctx context.Context) int { + return s.getInt(ctx, "slave_api_timeout", 60) +} + +func (s *settingProvider) MaxOnlineEditSize(ctx context.Context) int64 { + return int64(s.getInt(ctx, "maxEditSize", 52428800)) +} + +func (s *settingProvider) UploadSessionTTL(ctx context.Context) time.Duration { + return time.Duration(s.getInt(ctx, "upload_session_timeout", 86400)) * time.Second +} + +func (s *settingProvider) MaxBatchedFile(ctx context.Context) int { + return s.getInt(ctx, "max_batched_file", 3000) +} + +func (s *settingProvider) DBFS(ctx context.Context) *DBFS { + return &DBFS{ + UseCursorPagination: s.getBoolean(ctx, "use_cursor_pagination", true), + MaxPageSize: s.getInt(ctx, "max_page_size", 2000), + MaxRecursiveSearchedFolder: s.getInt(ctx, "max_recursive_searched_folder", 65535), + UseSSEForSearch: s.getBoolean(ctx, "use_sse_for_search", false), + } +} + +func (s *settingProvider) HashIDSalt(ctx context.Context) string { + return s.getString(ctx, "hash_id_salt", "") +} + +func (s *settingProvider) TokenAuth(ctx context.Context) *TokenAuth { + return &TokenAuth{ + AccessTokenTTL: time.Duration(s.getInt(ctx, "access_token_ttl", 3600)) * time.Second, + RefreshTokenTTL: time.Duration(s.getInt(ctx, "refresh_token_ttl", 15552000)) * time.Second, + } +} + +func (s *settingProvider) ResetEmailTemplate(ctx context.Context) []EmailTemplate { + src := s.getString(ctx, "mail_reset_template", "[]") + var templates []EmailTemplate + if err := json.Unmarshal([]byte(src), &templates); err != nil { + return []EmailTemplate{} + } + + return templates +} + +func (s *settingProvider) ActivationEmailTemplate(ctx context.Context) []EmailTemplate { + src := s.getString(ctx, "mail_activation_template", "[]") + var templates []EmailTemplate + if err := json.Unmarshal([]byte(src), &templates); err != nil { + return []EmailTemplate{} + } + + return templates +} + +func (s *settingProvider) SecretKey(ctx context.Context) string { 
+ return s.getString(ctx, "secret_key", "") +} + +func (s *settingProvider) AllSiteURLs(ctx context.Context) []*url.URL { + rawUrls := s.getStringList(ctx, "siteURL", []string{"http://localhost"}) + if len(rawUrls) == 0 { + rawUrls = []string{"http://localhost"} + } + + urls := make([]*url.URL, 0, len(rawUrls)) + for _, u := range rawUrls { + parsedURL, err := url.Parse(u) + if err != nil { + continue + } + urls = append(urls, parsedURL) + } + return urls +} + +func (s *settingProvider) SiteURL(ctx context.Context) *url.URL { + rawUrls := s.getStringList(ctx, "siteURL", []string{"http://localhost"}) + if len(rawUrls) == 0 { + rawUrls = []string{"http://localhost"} + } + + urls := make([]*url.URL, 0, len(rawUrls)) + for _, u := range rawUrls { + parsedURL, err := url.Parse(u) + if err != nil { + continue + } + urls = append(urls, parsedURL) + } + + reqInfo := requestinfo.RequestInfoFromContext(ctx) + _, useFirst := ctx.Value(UseFirstSiteUrlCtxKey{}).(bool) + if !useFirst && reqInfo != nil && reqInfo.Host != "" { + for _, u := range urls { + if (u.Host) == reqInfo.Host { + return u + } + } + } + + return urls[0] +} + +func (s *settingProvider) SMTP(ctx context.Context) *SMTP { + return &SMTP{ + FromName: s.getString(ctx, "fromName", ""), + From: s.getString(ctx, "fromAdress", ""), + Host: s.getString(ctx, "smtpHost", ""), + ReplyTo: s.getString(ctx, "replyTo", ""), + User: s.getString(ctx, "smtpUser", ""), + Password: s.getString(ctx, "smtpPass", ""), + ForceEncryption: s.getBoolean(ctx, "smtpEncryption", false), + Port: s.getInt(ctx, "smtpPort", 25), + Keepalive: s.getInt(ctx, "mail_keepalive", 30), + } +} + +func (s *settingProvider) DefaultGroup(ctx context.Context) int { + return s.getInt(ctx, "default_group", 2) +} + +func (s *settingProvider) EmailActivationEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "email_active", false) +} + +func (s *settingProvider) TcCaptcha(ctx context.Context) *TcCaptcha { + return &TcCaptcha{ + AppID: s.getString(ctx, "captcha_TCaptcha_CaptchaAppId", ""), + AppSecretKey: s.getString(ctx, "captcha_TCaptcha_AppSecretKey", ""), + SecretID: s.getString(ctx, "captcha_TCaptcha_SecretId", ""), + SecretKey: s.getString(ctx, "captcha_TCaptcha_SecretKey", ""), + } +} + +func (s *settingProvider) TurnstileCaptcha(ctx context.Context) *Turnstile { + return &Turnstile{ + Secret: s.getString(ctx, "captcha_turnstile_site_secret", ""), + Key: s.getString(ctx, "captcha_turnstile_site_key", ""), + } +} + +func (s *settingProvider) ReCaptcha(ctx context.Context) *ReCaptcha { + return &ReCaptcha{ + Secret: s.getString(ctx, "captcha_ReCaptchaSecret", ""), + Key: s.getString(ctx, "captcha_ReCaptchaKey", ""), + } +} + +func (s *settingProvider) CaptchaType(ctx context.Context) CaptchaType { + return CaptchaType(s.getString(ctx, "captcha_type", string(CaptchaNormal))) +} + +func (s *settingProvider) RegCaptchaEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "reg_captcha", false) +} + +func (s *settingProvider) LoginCaptchaEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "login_captcha", false) +} + +func (s *settingProvider) ForgotPasswordCaptchaEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "forget_captcha", false) +} + +func (s *settingProvider) AuthnEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "authn_enabled", false) +} + +func (s *settingProvider) RegisterEnabled(ctx context.Context) bool { + return s.getBoolean(ctx, "register_enabled", false) +} + +func (s *settingProvider) SiteBasic(ctx 
context.Context) *SiteBasic { + return &SiteBasic{ + Name: s.getString(ctx, "siteName", ""), + Title: s.getString(ctx, "siteTitle", ""), + ID: s.getString(ctx, "siteID", ""), + Description: s.getString(ctx, "siteDes", ""), + Script: s.getString(ctx, "siteScript", ""), + } +} + +func (s *settingProvider) PWA(ctx context.Context) *PWASetting { + return &PWASetting{ + SmallIcon: s.getString(ctx, "pwa_small_icon", ""), + MediumIcon: s.getString(ctx, "pwa_medium_icon", ""), + LargeIcon: s.getString(ctx, "pwa_large_icon", ""), + Display: s.getString(ctx, "pwa_display", ""), + ThemeColor: s.getString(ctx, "pwa_theme_color", ""), + BackgroundColor: s.getString(ctx, "pwa_background_color", ""), + } +} + +func IsTrueValue(val string) bool { + return val == "1" || val == "true" +} + +func (s *settingProvider) getInt(ctx context.Context, name string, defaultVal int) int { + val := s.adapterChain.Get(ctx, name, defaultVal) + if intVal, ok := val.(int); ok { + return intVal + } + + strVal := val.(string) + if intVal, err := strconv.Atoi(strVal); err == nil { + return intVal + } + + return defaultVal +} + +func (s *settingProvider) getInt64(ctx context.Context, name string, defaultVal int64) int64 { + val := s.adapterChain.Get(ctx, name, defaultVal) + if intVal, ok := val.(int64); ok { + return intVal + } + + strVal := val.(string) + if intVal, err := strconv.ParseInt(strVal, 10, 64); err == nil { + return intVal + } + + return defaultVal +} + +func (s *settingProvider) getFloat64(ctx context.Context, name string, defaultVal float64) float64 { + val := s.adapterChain.Get(ctx, name, defaultVal) + if intVal, ok := val.(float64); ok { + return intVal + } + + strVal := val.(string) + if intVal, err := strconv.ParseFloat(strVal, 64); err == nil { + return intVal + } + + return defaultVal +} + +func (s *settingProvider) getBoolean(ctx context.Context, name string, defaultVal bool) bool { + val := s.adapterChain.Get(ctx, name, defaultVal) + if intVal, ok := val.(bool); ok { + return intVal + } + + strVal := val.(string) + return IsTrueValue(strVal) +} + +func (s *settingProvider) getString(ctx context.Context, name string, defaultVal string) string { + val := s.adapterChain.Get(ctx, name, defaultVal) + return val.(string) +} + +func (s *settingProvider) getStringList(ctx context.Context, name string, defaultVal []string) []string { + res, _ := s.getStringListRaw(ctx, name, defaultVal) + return res +} + +func (s *settingProvider) getStringListRaw(ctx context.Context, name string, defaultVal []string) ([]string, string) { + val := s.getString(ctx, name, stringListDefault) + if val == stringListDefault { + return defaultVal, val + } + + return strings.Split(val, stringListDefaultSeparator), val +} + +func (s *settingProvider) getBoolSet(ctx context.Context, name string) *boolset.BooleanSet { + val := s.getString(ctx, name, "") + if val == "" { + return defaultBoolSet + } + + res, err := boolset.FromString(val) + if err != nil { + return defaultBoolSet + } + + return res +} + +func UseFirstSiteUrl(ctx context.Context) context.Context { + return context.WithValue(ctx, UseFirstSiteUrlCtxKey{}, true) +} diff --git a/pkg/setting/types.go b/pkg/setting/types.go new file mode 100644 index 00000000..b4c685da --- /dev/null +++ b/pkg/setting/types.go @@ -0,0 +1,239 @@ +package setting + +import ( + "time" +) + +type PWASetting struct { + SmallIcon string + MediumIcon string + LargeIcon string + Display string + ThemeColor string + BackgroundColor string +} + +type SiteBasic struct { + Name string + Title string + ID string 
+ Description string + Script string +} + +type CaptchaType string + +const ( + CaptchaNormal = CaptchaType("normal") + CaptchaReCaptcha = CaptchaType("recaptcha") + CaptchaTcaptcha = CaptchaType("tcaptcha") + CaptchaTurnstile = CaptchaType("turnstile") +) + +type ReCaptcha struct { + Key string + Secret string +} + +type TcCaptcha struct { + AppID string + AppSecretKey string + SecretID string + SecretKey string +} + +type Turnstile struct { + Key string + Secret string +} + +type SMTP struct { + FromName string + From string + Host string + ReplyTo string + User string + Password string + ForceEncryption bool + Port int + Keepalive int +} + +type TokenAuth struct { + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration +} + +type DBFS struct { + UseCursorPagination bool + MaxPageSize int + MaxRecursiveSearchedFolder int + UseSSEForSearch bool +} + +type ( + QueueType string + QueueSetting struct { + WorkerNum int + MaxExecution time.Duration + BackoffFactor float64 + BackoffMaxDuration time.Duration + MaxRetry int + RetryDelay time.Duration + } +) + +type ThumbEncode struct { + Quality int + Format string +} + +var ( + QueueTypeMediaMeta = QueueType("media_meta") + QueueTypeIOIntense = QueueType("io_intense") + QueueTypeThumb = QueueType("thumb") + QueueTypeEntityRecycle = QueueType("recycle") + QueueTypeSlave = QueueType("slave") + QueueTypeRemoteDownload = QueueType("remote_download") +) + +type CronType string + +var ( + CronTypeEntityCollect = CronType("entity_collect") + CronTypeTrashBinCollect = CronType("trash_bin_collect") + CronTypeOauthCredRefresh = CronType("oauth_cred_refresh") +) + +type Theme struct { + Themes string + DefaultTheme string +} + +type Logo struct { + Normal string + Light string +} + +type LegalDocuments struct { + PrivacyPolicy string + TermsOfService string +} + +type CaptchaMode int + +const ( + CaptchaModeNumber = CaptchaMode(iota) + CaptchaModeAlphabet + CaptchaModeArithmetic + CaptchaModeNumberAlphabet +) + +type Captcha struct { + Height int + Width int + Mode CaptchaMode + ComplexOfNoiseText int + ComplexOfNoiseDot int + IsShowHollowLine bool + IsShowNoiseDot bool + IsShowNoiseText bool + IsShowSlimeLine bool + IsShowSineLine bool + Length int +} + +type ExplorerFrontendSettings struct { + Icons string +} + +type MapProvider string + +const ( + MapProviderOpenStreetMap = MapProvider("openstreetmap") + MapProviderGoogle = MapProvider("google") +) + +type MapGoogleTileType string + +const ( + MapGoogleTileTypeRegular = MapGoogleTileType("regular") + MapGoogleTileTypeSatellite = MapGoogleTileType("satellite") + MapGoogleTileTypeTerrain = MapGoogleTileType("terrain") +) + +type MapSetting struct { + Provider MapProvider + GoogleTileType MapGoogleTileType +} + +// Viewer related + +type ( + ViewerAction string + ViewerType string +) + +const ( + ViewerActionView = "view" + ViewerActionEdit = "edit" + + ViewerTypeBuiltin = "builtin" + ViewerTypeWopi = "wopi" +) + +type Viewer struct { + ID string `json:"id"` + Type ViewerType `json:"type"` + DisplayName string `json:"display_name"` + Exts []string `json:"exts"` + Url string `json:"url,omitempty"` + Icon string `json:"icon,omitempty"` + WopiActions map[string]map[ViewerAction]string `json:"wopi_actions,omitempty"` + Props map[string]string `json:"props,omitempty"` + MaxSize int64 `json:"max_size,omitempty"` + Disabled bool `json:"disabled,omitempty"` + Templates []NewFileTemplate `json:"templates,omitempty"` +} + +type ViewerGroup struct { + Viewers []Viewer `json:"viewers"` +} + +type 
NewFileTemplate struct { + Ext string `json:"ext"` + DisplayName string `json:"display_name"` +} + +type ( + SearchCategory string +) + +const ( + CategoryUnknown = SearchCategory("unknown") + CategoryImage = SearchCategory("image") + CategoryVideo = SearchCategory("video") + CategoryAudio = SearchCategory("audio") + CategoryDocument = SearchCategory("document") +) + +type AppSetting struct { + Promotion bool +} + +type EmailTemplate struct { + Title string `json:"title"` + Body string `json:"body"` + Language string `json:"language"` +} + +type Avatar struct { + Gravatar string `json:"gravatar"` + Path string `json:"path"` +} + +type AvatarProcess struct { + Path string `json:"path"` + MaxFileSize int64 `json:"max_file_size"` + MaxWidth int `json:"max_width"` +} diff --git a/pkg/task/compress.go b/pkg/task/compress.go deleted file mode 100644 index 5e20a362..00000000 --- a/pkg/task/compress.go +++ /dev/null @@ -1,175 +0,0 @@ -package task - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// CompressTask 文件压缩任务 -type CompressTask struct { - User *model.User - TaskModel *model.Task - TaskProps CompressProps - Err *JobError - - zipPath string -} - -// CompressProps 压缩任务属性 -type CompressProps struct { - Dirs []uint `json:"dirs"` - Files []uint `json:"files"` - Dst string `json:"dst"` -} - -// Props 获取任务属性 -func (job *CompressTask) Props() string { - res, _ := json.Marshal(job.TaskProps) - return string(res) -} - -// Type 获取任务状态 -func (job *CompressTask) Type() int { - return CompressTaskType -} - -// Creator 获取创建者ID -func (job *CompressTask) Creator() uint { - return job.User.ID -} - -// Model 获取任务的数据库模型 -func (job *CompressTask) Model() *model.Task { - return job.TaskModel -} - -// SetStatus 设定状态 -func (job *CompressTask) SetStatus(status int) { - job.TaskModel.SetStatus(status) -} - -// SetError 设定任务失败信息 -func (job *CompressTask) SetError(err *JobError) { - job.Err = err - res, _ := json.Marshal(job.Err) - job.TaskModel.SetError(string(res)) - - // 删除压缩文件 - job.removeZipFile() -} - -func (job *CompressTask) removeZipFile() { - if job.zipPath != "" { - if err := os.Remove(job.zipPath); err != nil { - util.Log().Warning("Failed to delete temp zip file %q: %s", job.zipPath, err) - } - } -} - -// SetErrorMsg 设定任务失败信息 -func (job *CompressTask) SetErrorMsg(msg string) { - job.SetError(&JobError{Msg: msg}) -} - -// GetError 返回任务失败信息 -func (job *CompressTask) GetError() *JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *CompressTask) Do() { - // 创建文件系统 - fs, err := filesystem.NewFileSystem(job.User) - if err != nil { - job.SetErrorMsg(err.Error()) - return - } - - util.Log().Debug("Starting compress file...") - job.TaskModel.SetProgress(CompressingProgress) - - // 创建临时压缩文件 - saveFolder := "compress" - zipFilePath := filepath.Join( - util.RelativePath(model.GetSettingByName("temp_path")), - saveFolder, - fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()), - ) - zipFile, err := util.CreatNestedFile(zipFilePath) - if err != nil { - util.Log().Warning("%s", err) - job.SetErrorMsg(err.Error()) - return - } - - defer zipFile.Close() - - // 开始压缩 - ctx := context.Background() - err = fs.Compress(ctx, zipFile, job.TaskProps.Dirs, job.TaskProps.Files, false) - if err != nil { - job.SetErrorMsg(err.Error()) - return - } - - job.zipPath = zipFilePath - zipFile.Close() - util.Log().Debug("Compressed file saved 
to %q, start uploading it...", zipFilePath) - job.TaskModel.SetProgress(TransferringProgress) - - // 上传文件 - err = fs.UploadFromPath(ctx, zipFilePath, job.TaskProps.Dst, 0) - if err != nil { - job.SetErrorMsg(err.Error()) - return - } - - job.removeZipFile() -} - -// NewCompressTask 新建压缩任务 -func NewCompressTask(user *model.User, dst string, dirs, files []uint) (Job, error) { - newTask := &CompressTask{ - User: user, - TaskProps: CompressProps{ - Dirs: dirs, - Files: files, - Dst: dst, - }, - } - - record, err := Record(newTask) - if err != nil { - return nil, err - } - newTask.TaskModel = record - - return newTask, nil -} - -// NewCompressTaskFromModel 从数据库记录中恢复压缩任务 -func NewCompressTaskFromModel(task *model.Task) (Job, error) { - user, err := model.GetActiveUserByID(task.UserID) - if err != nil { - return nil, err - } - newTask := &CompressTask{ - User: &user, - TaskModel: task, - } - - err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) - if err != nil { - return nil, err - } - - return newTask, nil -} diff --git a/pkg/task/compress_test.go b/pkg/task/compress_test.go deleted file mode 100644 index 34b282dc..00000000 --- a/pkg/task/compress_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestCompressTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(CompressTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestCompressTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestCompressTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - zipPath: "test/TestCompressTask_SetError", - } - zipFile, _ := util.CreatNestedFile("test/TestCompressTask_SetError") - zipFile.Close() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(util.Exists("test/TestCompressTask_SetError")) - asserts.Equal("error", task.GetError().Msg) -} - -func TestCompressTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 压缩出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Dirs = []uint{1} - // 更新进度 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - // 查找目录 - 
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 上传出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - MaxSize: 1, - }, - } - task.TaskProps.Dirs = []uint{1} - cache.Set("setting_temp_path", "test", 0) - // 更新进度 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - // 查找目录 - mock.ExpectQuery("SELECT(.+)folders"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - // 查找文件 - mock.ExpectQuery("SELECT(.+)files"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 查找子文件 - mock.ExpectQuery("SELECT(.+)files"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 查找子目录 - mock.ExpectQuery("SELECT(.+)folders"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - asserts.True(util.IsEmpty(util.RelativePath("test/compress"))) - } -} - -func TestNewCompressTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewCompressTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewCompressTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewCompressTaskFromModel(&model.Task{Props: ""}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/decompress.go b/pkg/task/decompress.go deleted file mode 100644 index 9c6d88ea..00000000 --- a/pkg/task/decompress.go +++ /dev/null @@ -1,131 +0,0 @@ -package task - -import ( - "context" - "encoding/json" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" -) - -// DecompressTask 文件压缩任务 -type DecompressTask struct { - User *model.User - TaskModel *model.Task - TaskProps DecompressProps - Err *JobError - - zipPath string -} - -// DecompressProps 压缩任务属性 -type DecompressProps struct { - Src string `json:"src"` - Dst string `json:"dst"` - Encoding string `json:"encoding"` -} - -// Props 获取任务属性 -func (job *DecompressTask) Props() string { - res, _ := json.Marshal(job.TaskProps) - return string(res) -} - -// Type 获取任务状态 -func (job *DecompressTask) Type() int { - return DecompressTaskType -} - -// Creator 获取创建者ID -func (job *DecompressTask) Creator() uint { - return 
job.User.ID -} - -// Model 获取任务的数据库模型 -func (job *DecompressTask) Model() *model.Task { - return job.TaskModel -} - -// SetStatus 设定状态 -func (job *DecompressTask) SetStatus(status int) { - job.TaskModel.SetStatus(status) -} - -// SetError 设定任务失败信息 -func (job *DecompressTask) SetError(err *JobError) { - job.Err = err - res, _ := json.Marshal(job.Err) - job.TaskModel.SetError(string(res)) -} - -// SetErrorMsg 设定任务失败信息 -func (job *DecompressTask) SetErrorMsg(msg string, err error) { - jobErr := &JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - job.SetError(jobErr) -} - -// GetError 返回任务失败信息 -func (job *DecompressTask) GetError() *JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *DecompressTask) Do() { - // 创建文件系统 - fs, err := filesystem.NewFileSystem(job.User) - if err != nil { - job.SetErrorMsg("Failed to create filesystem.", err) - return - } - - job.TaskModel.SetProgress(DecompressingProgress) - - err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst, job.TaskProps.Encoding) - if err != nil { - job.SetErrorMsg("Failed to decompress file.", err) - return - } - -} - -// NewDecompressTask 新建压缩任务 -func NewDecompressTask(user *model.User, src, dst, encoding string) (Job, error) { - newTask := &DecompressTask{ - User: user, - TaskProps: DecompressProps{ - Src: src, - Dst: dst, - Encoding: encoding, - }, - } - - record, err := Record(newTask) - if err != nil { - return nil, err - } - newTask.TaskModel = record - - return newTask, nil -} - -// NewDecompressTaskFromModel 从数据库记录中恢复压缩任务 -func NewDecompressTaskFromModel(task *model.Task) (Job, error) { - user, err := model.GetActiveUserByID(task.UserID) - if err != nil { - return nil, err - } - newTask := &DecompressTask{ - User: &user, - TaskModel: task, - } - - err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) - if err != nil { - return nil, err - } - - return newTask, nil -} diff --git a/pkg/task/decompress_test.go b/pkg/task/decompress_test.go deleted file mode 100644 index 75b7cfe5..00000000 --- a/pkg/task/decompress_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestDecompressTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(DecompressTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestDecompressTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestDecompressTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - zipPath: "test/TestCompressTask_SetError", - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestDecompressTask_Do(t *testing.T) { - asserts := assert.New(t) - task := 
&DecompressTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 压缩文件不存在 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = "test" - task.Do() - asserts.NotEmpty(task.GetError().Msg) - } -} - -func TestNewDecompressTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewDecompressTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewDecompressTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewDecompressTaskFromModel(&model.Task{Props: ""}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/errors.go b/pkg/task/errors.go deleted file mode 100644 index f1fca169..00000000 --- a/pkg/task/errors.go +++ /dev/null @@ -1,8 +0,0 @@ -package task - -import "errors" - -var ( - // ErrUnknownTaskType 未知任务类型 - ErrUnknownTaskType = errors.New("unknown task type") -) diff --git a/pkg/task/import.go b/pkg/task/import.go deleted file mode 100644 index 607b4d1e..00000000 --- a/pkg/task/import.go +++ /dev/null @@ -1,221 +0,0 @@ -package task - -import ( - "context" - "encoding/json" - "path" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// ImportTask 导入务 -type ImportTask struct { - User *model.User - TaskModel *model.Task - TaskProps ImportProps - Err *JobError -} - -// ImportProps 导入任务属性 -type ImportProps struct { - PolicyID uint `json:"policy_id"` // 存储策略ID - Src string `json:"src"` // 原始路径 - Recursive bool `json:"is_recursive"` // 是否递归导入 - Dst string `json:"dst"` // 目的目录 -} - -// Props 获取任务属性 -func (job *ImportTask) Props() string { - res, _ := json.Marshal(job.TaskProps) - return string(res) -} - -// Type 获取任务状态 -func (job *ImportTask) Type() int { - return ImportTaskType -} - -// Creator 获取创建者ID -func (job *ImportTask) Creator() uint { - return job.User.ID -} - -// Model 获取任务的数据库模型 -func (job *ImportTask) Model() *model.Task { - return job.TaskModel -} - -// SetStatus 设定状态 -func (job *ImportTask) SetStatus(status int) { - job.TaskModel.SetStatus(status) -} - -// SetError 设定任务失败信息 -func (job *ImportTask) SetError(err *JobError) { - job.Err = err - res, _ := json.Marshal(job.Err) - 
job.TaskModel.SetError(string(res)) -} - -// SetErrorMsg 设定任务失败信息 -func (job *ImportTask) SetErrorMsg(msg string, err error) { - jobErr := &JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - job.SetError(jobErr) -} - -// GetError 返回任务失败信息 -func (job *ImportTask) GetError() *JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *ImportTask) Do() { - ctx := context.Background() - - // 查找存储策略 - policy, err := model.GetPolicyByID(job.TaskProps.PolicyID) - if err != nil { - job.SetErrorMsg("Policy not exist.", err) - return - } - - // 创建文件系统 - job.User.Policy = policy - fs, err := filesystem.NewFileSystem(job.User) - if err != nil { - job.SetErrorMsg(err.Error(), nil) - return - } - defer fs.Recycle() - - fs.Policy = &policy - if err := fs.DispatchHandler(); err != nil { - job.SetErrorMsg("Failed to dispatch policy.", err) - return - } - - // 注册钩子 - fs.Use("BeforeAddFile", filesystem.HookValidateFile) - fs.Use("BeforeAddFile", filesystem.HookValidateCapacity) - - // 列取目录、对象 - job.TaskModel.SetProgress(ListingProgress) - coxIgnoreConflict := context.WithValue(context.Background(), fsctx.IgnoreDirectoryConflictCtx, - true) - objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive) - if err != nil { - job.SetErrorMsg("Failed to list files.", err) - return - } - - job.TaskModel.SetProgress(InsertingProgress) - - // 虚拟目录路径与folder对象ID的对应 - pathCache := make(map[string]*model.Folder, len(objects)) - - // 插入目录记录到用户文件系统 - for _, object := range objects { - if object.IsDir { - // 创建目录 - virtualPath := path.Join(job.TaskProps.Dst, object.RelativePath) - folder, err := fs.CreateDirectory(coxIgnoreConflict, virtualPath) - if err != nil { - util.Log().Warning("Importing task cannot create user directory %q: %s", virtualPath, err) - } else if folder.ID > 0 { - pathCache[virtualPath] = folder - } - } - } - - // 插入文件记录到用户文件系统 - for _, object := range objects { - if !object.IsDir { - // 创建文件信息 - virtualPath := path.Dir(path.Join(job.TaskProps.Dst, object.RelativePath)) - fileHeader := fsctx.FileStream{ - Size: object.Size, - VirtualPath: virtualPath, - Name: object.Name, - SavePath: object.Source, - } - - // 查找父目录 - parentFolder := &model.Folder{} - if parent, ok := pathCache[virtualPath]; ok { - parentFolder = parent - } else { - folder, err := fs.CreateDirectory(context.Background(), virtualPath) - if err != nil { - util.Log().Warning("Importing task cannot create user directory %q: %s", - virtualPath, err) - continue - } - parentFolder = folder - - } - - // 插入文件记录 - _, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) - if err != nil { - util.Log().Warning("Importing task cannot insert user file %q: %s", - object.RelativePath, err) - if err == filesystem.ErrInsufficientCapacity { - job.SetErrorMsg("Insufficient storage capacity.", err) - return - } - } - - } - } -} - -// NewImportTask 新建导入任务 -func NewImportTask(user, policy uint, src, dst string, recursive bool) (Job, error) { - creator, err := model.GetActiveUserByID(user) - if err != nil { - return nil, err - } - - newTask := &ImportTask{ - User: &creator, - TaskProps: ImportProps{ - PolicyID: policy, - Recursive: recursive, - Src: src, - Dst: dst, - }, - } - - record, err := Record(newTask) - if err != nil { - return nil, err - } - newTask.TaskModel = record - - return newTask, nil -} - -// NewImportTaskFromModel 从数据库记录中恢复导入任务 -func NewImportTaskFromModel(task *model.Task) (Job, error) { - user, err := model.GetActiveUserByID(task.UserID) - if err != nil { - return nil, err - } - newTask := 
&ImportTask{ - User: &user, - TaskModel: task, - } - - err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) - if err != nil { - return nil, err - } - - return newTask, nil -} diff --git a/pkg/task/import_test.go b/pkg/task/import_test.go deleted file mode 100644 index a17123da..00000000 --- a/pkg/task/import_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestImportTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(ImportTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestImportTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestImportTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestImportTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - TaskProps: ImportProps{ - PolicyID: 63, - Src: "", - Recursive: false, - Dst: "", - }, - } - - // 存储策略不存在 - { - cache.Deletes([]string{"63"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 设定失败状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.Err.Error) - task.Err = nil - } - - // 无法分配 Filesystem - { - cache.Deletes([]string{"63"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "unknown")) - // 设定失败状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.Err.Msg) - task.Err = nil - } - - // 成功列取,但是文件为空 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "TestImportTask_Do/empty" - mock.ExpectQuery("SELECT(.+)policies(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } - - // 创建测试文件 - f, _ := util.CreatNestedFile(util.RelativePath("tests/TestImportTask_Do/test.txt")) - f.Close() - - // 成功列取,包含一个文件一个目录,父目录创建失败 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "tests" - task.TaskProps.Dst = "/" - task.TaskProps.Recursive = true - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 查找父目录,但是不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 仍然不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 创建文件时查找父目录,仍然不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - - task.Do() - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } - - // 成功列取,包含一个文件一个目录, 全部操作成功 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "tests" - task.TaskProps.Dst = "/" - task.TaskProps.Recursive = true - mock.ExpectQuery("SELECT(.+)policies(.+)"). 
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 查找父目录,存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - // 查找同名文件,不存在 - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 创建目录 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)folders(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - // 插入文件记录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - - task.Do() - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } -} - -func TestNewImportTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewImportTask(1, 1, "/", "/", false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewImportTask(1, 1, "/", "/", false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewImportTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewImportTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewImportTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/job.go b/pkg/task/job.go deleted file mode 100644 index d4804924..00000000 --- a/pkg/task/job.go +++ /dev/null @@ -1,123 +0,0 @@ -package task - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// 任务类型 -const ( - // CompressTaskType 压缩任务 - CompressTaskType = iota - // DecompressTaskType 解压缩任务 - DecompressTaskType - // TransferTaskType 中转任务 - TransferTaskType - // ImportTaskType 导入任务 - ImportTaskType - // RecycleTaskType 回收任务 - RecycleTaskType -) - -// 任务状态 -const ( - // Queued 排队中 - Queued = iota - // Processing 处理中 - Processing - // Error 失败 - Error - // Canceled 取消 - Canceled - // Complete 完成 - Complete -) - -// 任务进度 -const ( - // PendingProgress 等待中 - PendingProgress = iota - // Compressing 压缩中 - CompressingProgress - // Decompressing 解压缩中 - DecompressingProgress - // Downloading 下载中 - DownloadingProgress - // Transferring 转存中 - TransferringProgress - // ListingProgress 索引中 - ListingProgress - // InsertingProgress 
插入中 - InsertingProgress -) - -// Job 任务接口 -type Job interface { - Type() int // 返回任务类型 - Creator() uint // 返回创建者ID - Props() string // 返回序列化后的任务属性 - Model() *model.Task // 返回对应的数据库模型 - SetStatus(int) // 设定任务状态 - Do() // 开始执行任务 - SetError(*JobError) // 设定任务失败信息 - GetError() *JobError // 获取任务执行结果,返回nil表示成功完成执行 -} - -// JobError 任务失败信息 -type JobError struct { - Msg string `json:"msg,omitempty"` - Error string `json:"error,omitempty"` -} - -// Record 将任务记录到数据库中 -func Record(job Job) (*model.Task, error) { - record := model.Task{ - Status: Queued, - Type: job.Type(), - UserID: job.Creator(), - Progress: 0, - Error: "", - Props: job.Props(), - } - _, err := record.Create() - return &record, err -} - -// Resume 从数据库中恢复未完成任务 -func Resume(p Pool) { - tasks := model.GetTasksByStatus(Queued, Processing) - if len(tasks) == 0 { - return - } - util.Log().Info("Resume %d unfinished task(s) from database.", len(tasks)) - - for i := 0; i < len(tasks); i++ { - job, err := GetJobFromModel(&tasks[i]) - if err != nil { - util.Log().Warning("Failed to resume task: %s", err) - continue - } - - if job != nil { - p.Submit(job) - } - } -} - -// GetJobFromModel 从数据库给定模型获取任务 -func GetJobFromModel(task *model.Task) (Job, error) { - switch task.Type { - case CompressTaskType: - return NewCompressTaskFromModel(task) - case DecompressTaskType: - return NewDecompressTaskFromModel(task) - case TransferTaskType: - return NewTransferTaskFromModel(task) - case ImportTaskType: - return NewImportTaskFromModel(task) - case RecycleTaskType: - return NewRecycleTaskFromModel(task) - default: - return nil, ErrUnknownTaskType - } -} diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go deleted file mode 100644 index 737f5b76..00000000 --- a/pkg/task/job_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRecord(t *testing.T) { - asserts := assert.New(t) - job := &TransferTask{ - User: &model.User{Policy: model.Policy{Type: "unknown"}}, - } - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err := Record(job) - asserts.NoError(err) -} - -type taskPoolMock struct { - testMock.Mock -} - -func (t taskPoolMock) Add(num int) { - t.Called(num) -} - -func (t taskPoolMock) Submit(job Job) { - t.Called(job) -} - -func TestResume(t *testing.T) { - asserts := assert.New(t) - mockPool := taskPoolMock{} - - // 没有任务 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"})) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 有任务, 类型未知 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(233)) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 有任务 - { - mockPool.On("Submit", testMock.Anything) - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type", "props"}).AddRow(CompressTaskType, "{}")) - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - mockPool.AssertExpectations(t) - } -} - -func TestGetJobFromModel(t *testing.T) { - 
asserts := assert.New(t) - - // CompressTaskType - { - task := &model.Task{ - Status: 0, - Type: CompressTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // DecompressTaskType - { - task := &model.Task{ - Status: 0, - Type: DecompressTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // TransferTaskType - { - task := &model.Task{ - Status: 0, - Type: TransferTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // RecycleTaskType - { - task := &model.Task{ - Status: 0, - Type: RecycleTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} diff --git a/pkg/task/pool.go b/pkg/task/pool.go deleted file mode 100644 index e37f179b..00000000 --- a/pkg/task/pool.go +++ /dev/null @@ -1,68 +0,0 @@ -package task - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// TaskPoll 要使用的任务池 -var TaskPoll Pool - -type Pool interface { - Add(num int) - Submit(job Job) -} - -// AsyncPool 带有最大配额的任务池 -type AsyncPool struct { - // 容量 - idleWorker chan int -} - -// Add 增加可用Worker数量 -func (pool *AsyncPool) Add(num int) { - for i := 0; i < num; i++ { - pool.idleWorker <- 1 - } -} - -// ObtainWorker 阻塞直到获取新的Worker -func (pool *AsyncPool) obtainWorker() Worker { - select { - case <-pool.idleWorker: - // 有空闲Worker名额时,返回新Worker - return &GeneralWorker{} - } -} - -// FreeWorker 添加空闲Worker -func (pool *AsyncPool) freeWorker() { - pool.Add(1) -} - -// Submit 开始提交任务 -func (pool *AsyncPool) Submit(job Job) { - go func() { - util.Log().Debug("Waiting for Worker.") - worker := pool.obtainWorker() - util.Log().Debug("Worker obtained.") - worker.Do(job) - util.Log().Debug("Worker released.") - pool.freeWorker() - }() -} - -// Init 初始化任务池 -func Init() { - maxWorker := model.GetIntSetting("max_worker_num", 10) - TaskPoll = &AsyncPool{ - idleWorker: make(chan int, maxWorker), - } - TaskPoll.Add(maxWorker) - util.Log().Info("Initialize task queue with WorkerNum = %d", maxWorker) - - if conf.SystemConfig.Mode == "master" { - Resume(TaskPoll) - } -} diff --git a/pkg/task/pool_test.go b/pkg/task/pool_test.go deleted file mode 100644 index fbe41340..00000000 --- a/pkg/task/pool_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package task - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInit(t *testing.T) { - asserts := assert.New(t) - 
cache.Set("setting_max_worker_num", "10", 0) - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(-1)) - Init() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(TaskPoll.(*AsyncPool).idleWorker, 10) -} - -func TestPool_Submit(t *testing.T) { - asserts := assert.New(t) - pool := &AsyncPool{ - idleWorker: make(chan int, 1), - } - pool.Add(1) - job := &MockJob{ - DoFunc: func() { - - }, - } - asserts.NotPanics(func() { - pool.Submit(job) - }) -} diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go deleted file mode 100644 index 60cc97f1..00000000 --- a/pkg/task/recycle.go +++ /dev/null @@ -1,130 +0,0 @@ -package task - -import ( - "encoding/json" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// RecycleTask 文件回收任务 -type RecycleTask struct { - User *model.User - TaskModel *model.Task - TaskProps RecycleProps - Err *JobError -} - -// RecycleProps 回收任务属性 -type RecycleProps struct { - // 下载任务 GID - DownloadGID string `json:"download_gid"` -} - -// Props 获取任务属性 -func (job *RecycleTask) Props() string { - res, _ := json.Marshal(job.TaskProps) - return string(res) -} - -// Type 获取任务状态 -func (job *RecycleTask) Type() int { - return RecycleTaskType -} - -// Creator 获取创建者ID -func (job *RecycleTask) Creator() uint { - return job.User.ID -} - -// Model 获取任务的数据库模型 -func (job *RecycleTask) Model() *model.Task { - return job.TaskModel -} - -// SetStatus 设定状态 -func (job *RecycleTask) SetStatus(status int) { - job.TaskModel.SetStatus(status) -} - -// SetError 设定任务失败信息 -func (job *RecycleTask) SetError(err *JobError) { - job.Err = err - res, _ := json.Marshal(job.Err) - job.TaskModel.SetError(string(res)) -} - -// SetErrorMsg 设定任务失败信息 -func (job *RecycleTask) SetErrorMsg(msg string, err error) { - jobErr := &JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - job.SetError(jobErr) -} - -// GetError 返回任务失败信息 -func (job *RecycleTask) GetError() *JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *RecycleTask) Do() { - download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID) - if err != nil { - util.Log().Warning("Recycle task %d cannot found download record.", job.TaskModel.ID) - job.SetErrorMsg("Cannot found download task.", err) - return - } - nodeID := download.GetNodeID() - node := cluster.Default.GetNodeByID(nodeID) - if node == nil { - util.Log().Warning("Recycle task %d cannot found node.", job.TaskModel.ID) - job.SetErrorMsg("Invalid slave node.", nil) - return - } - err = node.GetAria2Instance().DeleteTempFile(download) - if err != nil { - util.Log().Warning("Failed to delete transfer temp folder %q: %s", download.Parent, err) - job.SetErrorMsg("Failed to recycle files.", err) - return - } -} - -// NewRecycleTask 新建回收任务 -func NewRecycleTask(download *model.Download) (Job, error) { - newTask := &RecycleTask{ - User: download.GetOwner(), - TaskProps: RecycleProps{ - DownloadGID: download.GID, - }, - } - - record, err := Record(newTask) - if err != nil { - return nil, err - } - newTask.TaskModel = record - - return newTask, nil -} - -// NewRecycleTaskFromModel 从数据库记录中恢复回收任务 -func NewRecycleTaskFromModel(task *model.Task) (Job, error) { - user, err := model.GetActiveUserByID(task.UserID) - if err != nil { - return nil, err - } - newTask := &RecycleTask{ - User: &user, - TaskModel: task, - } - - err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) - if err != nil { - 
return nil, err - } - - return newTask, nil -} diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go deleted file mode 100644 index 0092a30c..00000000 --- a/pkg/task/recycle_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestRecycleTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(RecycleTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestRecycleTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestRecycleTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestNewRecycleTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewRecycleTask(&model.Download{ - Model: gorm.Model{ID: 1}, - GID: "test_g_id", - Parent: "/", - UserID: 1, - NodeID: 1, - }) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewRecycleTask(&model.Download{ - Model: gorm.Model{ID: 1}, - GID: "test_g_id", - Parent: "test/not_exist", - UserID: 1, - NodeID: 1, - }) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewRecycleTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go deleted file mode 100644 index bdc59260..00000000 --- a/pkg/task/slavetask/transfer.go +++ /dev/null @@ -1,138 +0,0 @@ -package slavetask - -import ( - "context" - "os" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - 
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// TransferTask 文件中转任务 -type TransferTask struct { - Err *task.JobError - Req *serializer.SlaveTransferReq - MasterID string -} - -// Props 获取任务属性 -func (job *TransferTask) Props() string { - return "" -} - -// Type 获取任务类型 -func (job *TransferTask) Type() int { - return 0 -} - -// Creator 获取创建者ID -func (job *TransferTask) Creator() uint { - return 0 -} - -// Model 获取任务的数据库模型 -func (job *TransferTask) Model() *model.Task { - return nil -} - -// SetStatus 设定状态 -func (job *TransferTask) SetStatus(status int) { -} - -// SetError 设定任务失败信息 -func (job *TransferTask) SetError(err *task.JobError) { - job.Err = err - -} - -// SetErrorMsg 设定任务失败信息 -func (job *TransferTask) SetErrorMsg(msg string, err error) { - jobErr := &task.JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - - job.SetError(jobErr) - - notifyMsg := mq.Message{ - TriggeredBy: job.MasterID, - Event: serializer.SlaveTransferFailed, - Content: serializer.SlaveTransferResult{ - Error: err.Error(), - }, - } - - if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { - util.Log().Warning("Failed to send transfer failure notification to master node: %s", err) - } -} - -// GetError 返回任务失败信息 -func (job *TransferTask) GetError() *task.JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *TransferTask) Do() { - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - job.SetErrorMsg("Failed to initialize anonymous filesystem.", err) - return - } - - fs.Policy = job.Req.Policy - if err := fs.DispatchHandler(); err != nil { - job.SetErrorMsg("Failed to dispatch policy.", err) - return - } - - master, err := cluster.DefaultController.GetMasterInfo(job.MasterID) - if err != nil { - job.SetErrorMsg("Cannot found master node ID.", err) - return - } - - fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) - file, err := os.Open(util.RelativePath(job.Req.Src)) - if err != nil { - job.SetErrorMsg("Failed to read source file.", err) - return - } - - defer file.Close() - - // 获取源文件大小 - fi, err := file.Stat() - if err != nil { - job.SetErrorMsg("Failed to get source file size.", err) - return - } - - size := fi.Size() - - err = fs.Handler.Put(context.Background(), &fsctx.FileStream{ - File: file, - SavePath: job.Req.Dst, - Size: uint64(size), - }) - if err != nil { - job.SetErrorMsg("Upload failed.", err) - return - } - - msg := mq.Message{ - TriggeredBy: job.MasterID, - Event: serializer.SlaveTransferSuccess, - Content: serializer.SlaveTransferResult{}, - } - - if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { - util.Log().Warning("Failed to send transfer success notification to master node: %s", err) - } -} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go deleted file mode 100644 index 54bba479..00000000 --- a/pkg/task/tranfer.go +++ /dev/null @@ -1,190 +0,0 @@ -package task - -import ( - "context" - "encoding/json" - "fmt" - "path" - "path/filepath" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - 
"github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// TransferTask 文件中转任务 -type TransferTask struct { - User *model.User - TaskModel *model.Task - TaskProps TransferProps - Err *JobError - - zipPath string -} - -// TransferProps 中转任务属性 -type TransferProps struct { - Src []string `json:"src"` // 原始文件 - SrcSizes map[string]uint64 `json:"src_size"` // 原始文件的大小信息,从机转存时使用 - Parent string `json:"parent"` // 父目录 - Dst string `json:"dst"` // 目的目录ID - // 将会保留原始文件的目录结构,Src 除去 Parent 开头作为最终路径 - TrimPath bool `json:"trim_path"` - // 负责处理中专任务的节点ID - NodeID uint `json:"node_id"` -} - -// Props 获取任务属性 -func (job *TransferTask) Props() string { - res, _ := json.Marshal(job.TaskProps) - return string(res) -} - -// Type 获取任务状态 -func (job *TransferTask) Type() int { - return TransferTaskType -} - -// Creator 获取创建者ID -func (job *TransferTask) Creator() uint { - return job.User.ID -} - -// Model 获取任务的数据库模型 -func (job *TransferTask) Model() *model.Task { - return job.TaskModel -} - -// SetStatus 设定状态 -func (job *TransferTask) SetStatus(status int) { - job.TaskModel.SetStatus(status) -} - -// SetError 设定任务失败信息 -func (job *TransferTask) SetError(err *JobError) { - job.Err = err - res, _ := json.Marshal(job.Err) - job.TaskModel.SetError(string(res)) - -} - -// SetErrorMsg 设定任务失败信息 -func (job *TransferTask) SetErrorMsg(msg string, err error) { - jobErr := &JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - job.SetError(jobErr) -} - -// GetError 返回任务失败信息 -func (job *TransferTask) GetError() *JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *TransferTask) Do() { - // 创建文件系统 - fs, err := filesystem.NewFileSystem(job.User) - if err != nil { - job.SetErrorMsg(err.Error(), nil) - return - } - - successCount := 0 - errorList := make([]string, 0, len(job.TaskProps.Src)) - for _, file := range job.TaskProps.Src { - dst := path.Join(job.TaskProps.Dst, filepath.Base(file)) - if job.TaskProps.TrimPath { - // 保留原始目录 - trim := util.FormSlash(job.TaskProps.Parent) - src := util.FormSlash(file) - dst = path.Join(job.TaskProps.Dst, strings.TrimPrefix(src, trim)) - } - - if job.TaskProps.NodeID > 1 { - // 指定为从机中转 - - // 获取从机节点 - node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) - if node == nil { - job.SetErrorMsg("Invalid slave node.", nil) - } - - // 切换为从机节点处理上传 - fs.SwitchToSlaveHandler(node) - err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{ - File: nil, - Size: job.TaskProps.SrcSizes[file], - Name: path.Base(dst), - VirtualPath: path.Dir(dst), - Src: file, - }, false) - } else { - // 主机节点中转 - err = fs.UploadFromPath(context.Background(), file, dst, 0) - } - - if err != nil { - errorList = append(errorList, err.Error()) - } else { - successCount++ - job.TaskModel.SetProgress(successCount) - } - } - - if len(errorList) > 0 { - job.SetErrorMsg("Failed to transfer one or more file(s).", fmt.Errorf(strings.Join(errorList, "\n"))) - } - -} - -// NewTransferTask 新建中转任务 -func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { - creator, err := model.GetActiveUserByID(user) - if err != nil { - return nil, err - } - - newTask := &TransferTask{ - User: &creator, - TaskProps: TransferProps{ - Src: src, - Parent: parent, - Dst: dst, - TrimPath: trim, - NodeID: node, - SrcSizes: sizes, - }, - } - - record, err := Record(newTask) - if err != nil { - return nil, err - } - newTask.TaskModel = record - - return newTask, nil -} - -// NewTransferTaskFromModel 从数据库记录中恢复中转任务 -func NewTransferTaskFromModel(task *model.Task) 
(Job, error) { - user, err := model.GetActiveUserByID(task.UserID) - if err != nil { - return nil, err - } - newTask := &TransferTask{ - User: &user, - TaskModel: task, - } - - err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) - if err != nil { - return nil, err - } - - return newTask, nil -} diff --git a/pkg/task/transfer_test.go b/pkg/task/transfer_test.go deleted file mode 100644 index 612a4538..00000000 --- a/pkg/task/transfer_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestTransferTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(TransferTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestTransferTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTransferTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestTransferTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.TaskProps.Parent = "test/not_exist" - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 上传出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = []string{"test/not_exist"} - task.TaskProps.Parent = "test/not_exist" - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 替换目录前缀 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = []string{"test/not_exist"} - task.TaskProps.Parent = "test/not_exist" - task.TaskProps.TrimPath = true - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } -} - -func TestNewTransferTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil) - 
asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewTransferTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewTransferTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewTransferTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/worker.go b/pkg/task/worker.go deleted file mode 100644 index e40a3b5a..00000000 --- a/pkg/task/worker.go +++ /dev/null @@ -1,44 +0,0 @@ -package task - -import ( - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// Worker 处理任务的对象 -type Worker interface { - Do(Job) // 执行任务 -} - -// GeneralWorker 通用Worker -type GeneralWorker struct { -} - -// Do 执行任务 -func (worker *GeneralWorker) Do(job Job) { - util.Log().Debug("Start executing task.") - job.SetStatus(Processing) - - defer func() { - // 致命错误捕获 - if err := recover(); err != nil { - util.Log().Debug("Failed to execute task: %s", err) - job.SetError(&JobError{Msg: "Fatal error.", Error: fmt.Sprintf("%s", err)}) - job.SetStatus(Error) - } - }() - - // 开始执行任务 - job.Do() - - // 任务执行失败 - if err := job.GetError(); err != nil { - util.Log().Debug("Failed to execute task.") - job.SetStatus(Error) - return - } - - util.Log().Debug("Task finished.") - // 执行完成 - job.SetStatus(Complete) -} diff --git a/pkg/task/worker_test.go b/pkg/task/worker_test.go deleted file mode 100644 index 64c6551c..00000000 --- a/pkg/task/worker_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package task - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" -) - -type MockJob struct { - Err *JobError - Status int - DoFunc func() -} - -func (job *MockJob) Type() int { - panic("implement me") -} - -func (job *MockJob) Creator() uint { - panic("implement me") -} - -func (job *MockJob) Props() string { - panic("implement me") -} - -func (job *MockJob) Model() *model.Task { - panic("implement me") -} - -func (job *MockJob) SetStatus(status int) { - job.Status = status -} - -func (job *MockJob) Do() { - job.DoFunc() -} - -func (job *MockJob) SetError(*JobError) { -} - -func (job *MockJob) GetError() *JobError { - return job.Err -} - -func TestGeneralWorker_Do(t *testing.T) { - asserts := assert.New(t) - worker := &GeneralWorker{} - job := &MockJob{} - - // 正常 - { - job.DoFunc = func() { - } - worker.Do(job) - asserts.Equal(Complete, job.Status) - } - - // 有错误 - { - job.DoFunc = func() { - } - job.Status = Queued - job.Err = &JobError{Msg: "error"} - worker.Do(job) - asserts.Equal(Error, job.Status) - } - - // 有致命错误 - { - job.DoFunc = func() { - panic("mock fatal error") - } - job.Status = Queued - job.Err = nil - worker.Do(job) - asserts.Equal(Error, job.Status) - } - -} diff --git a/pkg/thumb/builtin.go b/pkg/thumb/builtin.go 
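Editorial aside, not part of the diff: the pkg/task framework removed above (job.go, pool.go, worker.go and their tests) gated concurrency with a buffered channel used as a worker quota — Init() pre-filled idleWorker with max_worker_num tokens and Submit blocked in a goroutine until one was free. A self-contained, simplified Go sketch of that pattern follows; names and types are trimmed for illustration, and a WaitGroup is added so the example exits cleanly.

package main

import (
	"fmt"
	"sync"
)

// Job mirrors the minimal surface of the removed task.Job that the pool cared about.
type Job interface {
	Do()
}

// AsyncPool limits concurrency with a buffered channel used as a worker quota,
// the same idea as the removed pkg/task AsyncPool.idleWorker channel.
type AsyncPool struct {
	idleWorker chan struct{}
	wg         sync.WaitGroup
}

func NewAsyncPool(maxWorker int) *AsyncPool {
	p := &AsyncPool{idleWorker: make(chan struct{}, maxWorker)}
	for i := 0; i < maxWorker; i++ {
		p.idleWorker <- struct{}{} // pre-fill the quota, like Pool.Add(maxWorker)
	}
	return p
}

// Submit runs the job in a goroutine once a worker slot is available,
// then returns the slot when the job finishes.
func (p *AsyncPool) Submit(job Job) {
	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		<-p.idleWorker
		defer func() { p.idleWorker <- struct{}{} }()
		job.Do()
	}()
}

type printJob struct{ msg string }

func (j printJob) Do() { fmt.Println(j.msg) }

func main() {
	pool := NewAsyncPool(2) // at most two jobs run concurrently
	for i := 0; i < 5; i++ {
		pool.Submit(printJob{msg: fmt.Sprintf("job %d", i)})
	}
	pool.wg.Wait()
}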
index 206d0465..3a9f2ff4 100644 --- a/pkg/thumb/builtin.go +++ b/pkg/thumb/builtin.go @@ -3,24 +3,21 @@ package thumb import ( "context" "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gofrs/uuid" "image" "image/gif" "image/jpeg" "image/png" "io" "path/filepath" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" //"github.com/nfnt/resize" "golang.org/x/image/draw" ) -func init() { - RegisterGenerator(&Builtin{}) -} +const thumbTempFolder = "thumb" // Thumb 缩略图 type Thumb struct { @@ -30,16 +27,15 @@ type Thumb struct { // NewThumbFromFile 从文件数据获取新的Thumb对象, // 尝试通过文件名name解码图像 -func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) { - ext := strings.ToLower(filepath.Ext(name)) +func NewThumbFromFile(file io.Reader, ext string) (*Thumb, error) { // 无扩展名时 - if len(ext) == 0 { + if ext == "" { return nil, fmt.Errorf("unknown image format: %w", ErrPassThrough) } var err error var img image.Image - switch ext[1:] { + switch ext { case "jpg", "jpeg": img, err = jpeg.Decode(file) case "gif": @@ -47,7 +43,7 @@ func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) { case "png": img, err = png.Decode(file) default: - return nil, fmt.Errorf("unknown image format: %w", ErrPassThrough) + return nil, fmt.Errorf("unknown image format %q: %w", ext, ErrPassThrough) } if err != nil { return nil, fmt.Errorf("failed to parse image: %w (%w)", err, ErrPassThrough) @@ -72,12 +68,12 @@ func (image *Thumb) GetSize() (int, int) { } // Save 保存图像到给定路径 -func (image *Thumb) Save(w io.Writer) (err error) { - switch model.GetSettingByNameWithDefault("thumb_encode_method", "jpg") { +func (image *Thumb) Save(w io.Writer, encodeSetting *setting.ThumbEncode) (err error) { + switch encodeSetting.Format { case "png": err = png.Encode(w, image.src) default: - err = jpeg.Encode(w, image.src, &jpeg.Options{Quality: model.GetIntSetting("thumb_encode_quality", 85)}) + err = jpeg.Encode(w, image.src, &jpeg.Options{Quality: encodeSetting.Quality}) } return err @@ -127,46 +123,35 @@ func Resize(newWidth, newHeight uint, img image.Image) image.Image { } // CreateAvatar 创建头像 -func (image *Thumb) CreateAvatar(uid uint) error { - // 读取头像相关设定 - savePath := util.RelativePath(model.GetSettingByName("avatar_path")) - s := model.GetIntSetting("avatar_size_s", 50) - m := model.GetIntSetting("avatar_size_m", 130) - l := model.GetIntSetting("avatar_size_l", 200) - - // 生成头像缩略图 - src := image.src - for k, size := range []int{s, m, l} { - out, err := util.CreatNestedFile(filepath.Join(savePath, fmt.Sprintf("avatar_%d_%d.png", uid, k))) - - if err != nil { - return err - } - defer out.Close() - - image.src = Resize(uint(size), uint(size), src) - err = image.Save(out) - if err != nil { - return err - } - } +func (image *Thumb) CreateAvatar(width int) { + image.src = Resize(uint(width), uint(width), image.src) +} - return nil +type Builtin struct { + settings setting.Provider +} +func NewBuiltinGenerator(settings setting.Provider) *Builtin { + return &Builtin{ + settings: settings, + } } -type Builtin struct{} +func (b Builtin) Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) { + if es.Entity().Size() > b.settings.BuiltinThumbMaxSize(ctx) { + return nil, fmt.Errorf("file is too big: %w", ErrPassThrough) + } -func (b Builtin) Generate(ctx 
context.Context, file io.Reader, src, name string, options map[string]string) (*Result, error) { - img, err := NewThumbFromFile(file, name) + img, err := NewThumbFromFile(es, ext) if err != nil { return nil, err } - img.GetThumb(thumbSize(options)) + w, h := b.settings.ThumbSize(ctx) + img.GetThumb(uint(w), uint(h)) tempPath := filepath.Join( - util.RelativePath(model.GetSettingByName("temp_path")), - "thumb", + util.DataPath(b.settings.TempPath(ctx)), + thumbTempFolder, fmt.Sprintf("thumb_%s", uuid.Must(uuid.NewV4()).String()), ) @@ -176,8 +161,8 @@ func (b Builtin) Generate(ctx context.Context, file io.Reader, src, name string, } defer thumbFile.Close() - if err := img.Save(thumbFile); err != nil { - return nil, err + if err := img.Save(thumbFile, b.settings.ThumbEncode(ctx)); err != nil { + return &Result{Path: tempPath}, err } return &Result{Path: tempPath}, nil @@ -187,6 +172,6 @@ func (b Builtin) Priority() int { return 300 } -func (b Builtin) EnableFlag() string { - return "thumb_builtin_enabled" +func (b Builtin) Enabled(ctx context.Context) bool { + return b.settings.BuiltinThumbGeneratorEnabled(ctx) } diff --git a/pkg/thumb/ffmpeg.go b/pkg/thumb/ffmpeg.go index 45814b37..c94a8297 100644 --- a/pkg/thumb/ffmpeg.go +++ b/pkg/thumb/ffmpeg.go @@ -4,90 +4,77 @@ import ( "bytes" "context" "fmt" - "io" - "os" "os/exec" "path/filepath" - "strings" + "time" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gofrs/uuid" ) -func init() { - RegisterGenerator(&FfmpegGenerator{}) +const ( + urlTimeout = time.Duration(1) * time.Hour +) + +func NewFfmpegGenerator(l logging.Logger, settings setting.Provider) *FfmpegGenerator { + return &FfmpegGenerator{l: l, settings: settings} } type FfmpegGenerator struct { - exts []string - lastRawExts string + l logging.Logger + settings setting.Provider } -func (f *FfmpegGenerator) Generate(ctx context.Context, file io.Reader, src, name string, options map[string]string) (*Result, error) { - const ( - thumbFFMpegPath = "thumb_ffmpeg_path" - thumbFFMpegExts = "thumb_ffmpeg_exts" - thumbFFMpegSeek = "thumb_ffmpeg_seek" - thumbEncodeMethod = "thumb_encode_method" - tempPath = "temp_path" - ) - ffmpegOpts := model.GetSettingByNames(thumbFFMpegPath, thumbFFMpegExts, thumbFFMpegSeek, thumbEncodeMethod, tempPath) - - if f.lastRawExts != ffmpegOpts[thumbFFMpegExts] { - f.exts = strings.Split(ffmpegOpts[thumbFFMpegExts], ",") - f.lastRawExts = ffmpegOpts[thumbFFMpegExts] +func (f *FfmpegGenerator) Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) { + if !util.IsInExtensionListExt(f.settings.FFMpegThumbExts(ctx), ext) { + return nil, fmt.Errorf("unsupported video format: %w", ErrPassThrough) } - if !util.IsInExtensionList(f.exts, name) { - return nil, fmt.Errorf("unsupported video format: %w", ErrPassThrough) + if es.Entity().Size() > f.settings.FFMpegThumbMaxSize(ctx) { + return nil, fmt.Errorf("file is too big: %w", ErrPassThrough) } tempOutputPath := filepath.Join( - util.RelativePath(ffmpegOpts[tempPath]), - "thumb", - fmt.Sprintf("thumb_%s.%s", uuid.Must(uuid.NewV4()).String(), ffmpegOpts[thumbEncodeMethod]), + util.DataPath(f.settings.TempPath(ctx)), + thumbTempFolder, + 
fmt.Sprintf("thumb_%s.%s", uuid.Must(uuid.NewV4()).String(), f.settings.ThumbEncode(ctx).Format), ) - tempInputPath := src - if tempInputPath == "" { - // If not local policy files, download to temp folder - tempInputPath = filepath.Join( - util.RelativePath(ffmpegOpts[tempPath]), - "thumb", - fmt.Sprintf("ffmpeg_%s%s", uuid.Must(uuid.NewV4()).String(), filepath.Ext(name)), - ) - - // Due to limitations of ffmpeg, we need to write the input file to disk first - tempInputFile, err := util.CreatNestedFile(tempInputPath) - if err != nil { - return nil, fmt.Errorf("failed to create temp file: %w", err) - } - - defer os.Remove(tempInputPath) - defer tempInputFile.Close() + if err := util.CreatNestedFolder(filepath.Dir(tempOutputPath)); err != nil { + return nil, fmt.Errorf("failed to create temp folder: %w", err) + } - if _, err = io.Copy(tempInputFile, file); err != nil { - return nil, fmt.Errorf("failed to write input file: %w", err) + input := "" + expire := time.Now().Add(urlTimeout) + if es.IsLocal() { + input = es.LocalPath(ctx) + } else { + src, err := es.Url(driver.WithForcePublicEndpoint(ctx, false), entitysource.WithNoInternalProxy(), entitysource.WithContext(ctx), entitysource.WithExpire(&expire)) + if err != nil { + return &Result{Path: tempOutputPath}, fmt.Errorf("failed to get entity url: %w", err) } - tempInputFile.Close() + input = src.Url } // Invoke ffmpeg - scaleOpt := fmt.Sprintf("scale=%s:%s:force_original_aspect_ratio=decrease", options["thumb_width"], options["thumb_height"]) + w, h := f.settings.ThumbSize(ctx) + scaleOpt := fmt.Sprintf("scale=%d:%d:force_original_aspect_ratio=decrease", w, h) cmd := exec.CommandContext(ctx, - ffmpegOpts[thumbFFMpegPath], "-ss", ffmpegOpts[thumbFFMpegSeek], "-i", tempInputPath, + f.settings.FFMpegPath(ctx), "-ss", f.settings.FFMpegThumbSeek(ctx), "-i", input, "-vf", scaleOpt, "-vframes", "1", tempOutputPath) // Redirect IO var stdErr bytes.Buffer - cmd.Stdin = file cmd.Stderr = &stdErr if err := cmd.Run(); err != nil { - util.Log().Warning("Failed to invoke ffmpeg: %s", stdErr.String()) - return nil, fmt.Errorf("failed to invoke ffmpeg: %w", err) + f.l.Warning("Failed to invoke ffmpeg: %s", stdErr.String()) + return &Result{Path: tempOutputPath}, fmt.Errorf("failed to invoke ffmpeg: %w, raw output: %s", err, stdErr.String()) } return &Result{Path: tempOutputPath}, nil @@ -97,6 +84,6 @@ func (f *FfmpegGenerator) Priority() int { return 200 } -func (f *FfmpegGenerator) EnableFlag() string { - return "thumb_ffmpeg_enabled" +func (f *FfmpegGenerator) Enabled(ctx context.Context) bool { + return f.settings.FFMpegThumbGeneratorEnabled(ctx) } diff --git a/pkg/thumb/libraw.go b/pkg/thumb/libraw.go deleted file mode 100644 index 089e5f80..00000000 --- a/pkg/thumb/libraw.go +++ /dev/null @@ -1,283 +0,0 @@ -package thumb - -import ( - "bytes" - "context" - "errors" - "fmt" - "image" - "image/jpeg" - "image/png" - "io" - "os" - "os/exec" - "path/filepath" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" -) - -func init() { - RegisterGenerator(&LibRawGenerator{}) -} - -type LibRawGenerator struct { - exts []string - lastRawExts string -} - -func (f *LibRawGenerator) Generate(ctx context.Context, file io.Reader, _ string, name string, options map[string]string) (*Result, error) { - const ( - thumbLibRawPath = "thumb_libraw_path" - thumbLibRawExt = "thumb_libraw_exts" - thumbTempPath = "temp_path" - ) - - opts := model.GetSettingByNames(thumbLibRawPath, 
thumbLibRawExt, thumbTempPath) - - if f.lastRawExts != opts[thumbLibRawExt] { - f.exts = strings.Split(opts[thumbLibRawExt], ",") - f.lastRawExts = opts[thumbLibRawExt] - } - - if !util.IsInExtensionList(f.exts, name) { - return nil, fmt.Errorf("unsupported image format: %w", ErrPassThrough) - } - - inputFilePath := filepath.Join( - util.RelativePath(opts[thumbTempPath]), - "thumb", - fmt.Sprintf("thumb_%s", uuid.Must(uuid.NewV4()).String()), - ) - defer func() { _ = os.Remove(inputFilePath) }() - - inputFile, err := util.CreatNestedFile(inputFilePath) - if err != nil { - return nil, fmt.Errorf("failed to create temp file: %w", err) - } - - if _, err = io.Copy(inputFile, file); err != nil { - _ = inputFile.Close() - return nil, fmt.Errorf("failed to write input file: %w", err) - } - _ = inputFile.Close() - - cmd := exec.CommandContext(ctx, opts[thumbLibRawPath], "-e", inputFilePath) - - var stdErr bytes.Buffer - cmd.Stderr = &stdErr - if err = cmd.Run(); err != nil { - util.Log().Warning("Failed to invoke LibRaw: %s", stdErr.String()) - return nil, fmt.Errorf("failed to invoke LibRaw: %w", err) - } - - outputFilePath := inputFilePath + ".thumb.jpg" - defer func() { _ = os.Remove(outputFilePath) }() - - ff, err := os.Open(outputFilePath) - if err != nil { - return nil, fmt.Errorf("failed to open temp file: %w", err) - } - defer func() { _ = ff.Close() }() - - // use builtin generator - result, err := new(Builtin).Generate(ctx, ff, outputFilePath, filepath.Base(outputFilePath), options) - if err != nil { - return nil, fmt.Errorf("failed to generate thumbnail: %w", err) - } - - orientation, err := getJpegOrientation(outputFilePath) - if err != nil { - return nil, fmt.Errorf("failed to get jpeg orientation: %w", err) - } - if orientation == 1 { - return result, nil - } - - if err = rotateImg(result.Path, orientation); err != nil { - return nil, fmt.Errorf("failed to rotate image: %w", err) - } - return result, nil -} - -func rotateImg(filePath string, orientation int) error { - resultImg, err := os.OpenFile(filePath, os.O_RDWR, 0777) - if err != nil { - return err - } - defer func() { _ = resultImg.Close() }() - - imgFlag := make([]byte, 3) - if _, err = io.ReadFull(resultImg, imgFlag); err != nil { - return err - } - if _, err = resultImg.Seek(0, 0); err != nil { - return err - } - - var img image.Image - if bytes.Equal(imgFlag, []byte{0xFF, 0xD8, 0xFF}) { - img, err = jpeg.Decode(resultImg) - } else { - img, err = png.Decode(resultImg) - } - if err != nil { - return err - } - - switch orientation { - case 8: - img = rotate90(img) - case 3: - img = rotate90(rotate90(img)) - case 6: - img = rotate90(rotate90(rotate90(img))) - case 2: - img = mirrorImg(img) - case 7: - img = rotate90(mirrorImg(img)) - case 4: - img = rotate90(rotate90(mirrorImg(img))) - case 5: - img = rotate90(rotate90(rotate90(mirrorImg(img)))) - } - - if err = resultImg.Truncate(0); err != nil { - return err - } - if _, err = resultImg.Seek(0, 0); err != nil { - return err - } - - if bytes.Equal(imgFlag, []byte{0xFF, 0xD8, 0xFF}) { - return jpeg.Encode(resultImg, img, nil) - } - return png.Encode(resultImg, img) -} - -func getJpegOrientation(fileName string) (int, error) { - f, err := os.Open(fileName) - if err != nil { - return 0, err - } - defer func() { _ = f.Close() }() - - header := make([]byte, 6) - defer func() { header = nil }() - if _, err = io.ReadFull(f, header); err != nil { - return 0, err - } - - // jpeg format header - if !bytes.Equal(header[:3], []byte{0xFF, 0xD8, 0xFF}) { - return 0, errors.New("not a jpeg") 
- } - - // not a APP1 marker - if header[3] != 0xE1 { - return 1, nil - } - - // exif data total length - totalLen := int(header[4])<<8 + int(header[5]) - 2 - buf := make([]byte, totalLen) - defer func() { buf = nil }() - if _, err = io.ReadFull(f, buf); err != nil { - return 0, err - } - - // remove Exif identifier code - buf = buf[6:] - - // byte order - parse16, parse32, err := initParseMethod(buf[:2]) - if err != nil { - return 0, err - } - - // version - _ = buf[2:4] - - // first IFD offset - offset := parse32(buf[4:8]) - - // first DE offset - offset += 2 - buf = buf[offset:] - - const ( - orientationTag = 0x112 - deEntryLength = 12 - ) - for len(buf) > deEntryLength { - tag := parse16(buf[:2]) - if tag == orientationTag { - return int(parse32(buf[8:12])), nil - } - buf = buf[deEntryLength:] - } - - return 0, errors.New("orientation not found") -} - -func initParseMethod(buf []byte) (func([]byte) int16, func([]byte) int32, error) { - if bytes.Equal(buf, []byte{0x49, 0x49}) { - return littleEndian16, littleEndian32, nil - } - if bytes.Equal(buf, []byte{0x4D, 0x4D}) { - return bigEndian16, bigEndian32, nil - } - return nil, nil, errors.New("invalid byte order") -} - -func littleEndian16(buf []byte) int16 { - return int16(buf[0]) | int16(buf[1])<<8 -} - -func bigEndian16(buf []byte) int16 { - return int16(buf[1]) | int16(buf[0])<<8 -} - -func littleEndian32(buf []byte) int32 { - return int32(buf[0]) | int32(buf[1])<<8 | int32(buf[2])<<16 | int32(buf[3])<<24 -} - -func bigEndian32(buf []byte) int32 { - return int32(buf[3]) | int32(buf[2])<<8 | int32(buf[1])<<16 | int32(buf[0])<<24 -} - -func rotate90(img image.Image) image.Image { - bounds := img.Bounds() - width, height := bounds.Dx(), bounds.Dy() - newImg := image.NewRGBA(image.Rect(0, 0, height, width)) - for x := 0; x < width; x++ { - for y := 0; y < height; y++ { - newImg.Set(y, width-x-1, img.At(x, y)) - } - } - return newImg -} - -func mirrorImg(img image.Image) image.Image { - bounds := img.Bounds() - width, height := bounds.Dx(), bounds.Dy() - newImg := image.NewRGBA(image.Rect(0, 0, width, height)) - for x := 0; x < width; x++ { - for y := 0; y < height; y++ { - newImg.Set(width-x-1, y, img.At(x, y)) - } - } - return newImg -} - -func (f *LibRawGenerator) Priority() int { - return 250 -} - -func (f *LibRawGenerator) EnableFlag() string { - return "thumb_libraw_enabled" -} - -var _ Generator = (*LibRawGenerator)(nil) diff --git a/pkg/thumb/libreoffice.go b/pkg/thumb/libreoffice.go index 74926134..609fade2 100644 --- a/pkg/thumb/libreoffice.go +++ b/pkg/thumb/libreoffice.go @@ -10,51 +10,46 @@ import ( "path/filepath" "strings" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gofrs/uuid" ) -func init() { - RegisterGenerator(&LibreOfficeGenerator{}) +func NewLibreOfficeGenerator(l logging.Logger, settings setting.Provider) *LibreOfficeGenerator { + return &LibreOfficeGenerator{l: l, settings: settings} } type LibreOfficeGenerator struct { - exts []string - lastRawExts string + settings setting.Provider + l logging.Logger } -func (l *LibreOfficeGenerator) Generate(ctx context.Context, file io.Reader, src string, name string, options map[string]string) (*Result, error) { - const ( - thumbLibreOfficePath = "thumb_libreoffice_path" - thumbLibreOfficeExts = 
"thumb_libreoffice_exts" - thumbEncodeMethod = "thumb_encode_method" - tempPath = "temp_path" - ) - sofficeOpts := model.GetSettingByNames(thumbLibreOfficePath, thumbLibreOfficeExts, thumbEncodeMethod, tempPath) - - if l.lastRawExts != sofficeOpts[thumbLibreOfficeExts] { - l.exts = strings.Split(sofficeOpts[thumbLibreOfficeExts], ",") - l.lastRawExts = sofficeOpts[thumbLibreOfficeExts] +func (l *LibreOfficeGenerator) Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) { + if !util.IsInExtensionListExt(l.settings.LibreOfficeThumbExts(ctx), ext) { + return nil, fmt.Errorf("unsupported video format: %w", ErrPassThrough) } - if !util.IsInExtensionList(l.exts, name) { - return nil, fmt.Errorf("unsupported document format: %w", ErrPassThrough) + if es.Entity().Size() > l.settings.LibreOfficeThumbMaxSize(ctx) { + return nil, fmt.Errorf("file is too big: %w", ErrPassThrough) } tempOutputPath := filepath.Join( - util.RelativePath(sofficeOpts[tempPath]), - "thumb", + util.DataPath(l.settings.TempPath(ctx)), + thumbTempFolder, fmt.Sprintf("soffice_%s", uuid.Must(uuid.NewV4()).String()), ) - tempInputPath := src - if tempInputPath == "" { + tempInputPath := "" + if es.IsLocal() { + tempInputPath = es.LocalPath(ctx) + } else { // If not local policy files, download to temp folder tempInputPath = filepath.Join( - util.RelativePath(sofficeOpts[tempPath]), + util.DataPath(l.settings.TempPath(ctx)), "thumb", - fmt.Sprintf("soffice_%s%s", uuid.Must(uuid.NewV4()).String(), filepath.Ext(name)), + fmt.Sprintf("soffice_%s.%s", uuid.Must(uuid.NewV4()).String(), ext), ) // Due to limitations of ffmpeg, we need to write the input file to disk first @@ -66,32 +61,32 @@ func (l *LibreOfficeGenerator) Generate(ctx context.Context, file io.Reader, src defer os.Remove(tempInputPath) defer tempInputFile.Close() - if _, err = io.Copy(tempInputFile, file); err != nil { - return nil, fmt.Errorf("failed to write input file: %w", err) + if _, err = io.Copy(tempInputFile, es); err != nil { + return &Result{Path: tempOutputPath}, fmt.Errorf("failed to write input file: %w", err) } tempInputFile.Close() } // Convert the document to an image - cmd := exec.CommandContext(ctx, sofficeOpts[thumbLibreOfficePath], "--headless", + encode := l.settings.ThumbEncode(ctx) + cmd := exec.CommandContext(ctx, l.settings.LibreOfficePath(ctx), "--headless", "-nologo", "--nofirststartwizard", "--invisible", "--norestore", "--convert-to", - sofficeOpts[thumbEncodeMethod], "--outdir", tempOutputPath, tempInputPath) + encode.Format, "--outdir", tempOutputPath, tempInputPath) // Redirect IO var stdErr bytes.Buffer - cmd.Stdin = file cmd.Stderr = &stdErr if err := cmd.Run(); err != nil { - util.Log().Warning("Failed to invoke LibreOffice: %s", stdErr.String()) - return nil, fmt.Errorf("failed to invoke LibreOffice: %w", err) + l.l.Warning("Failed to invoke LibreOffice: %s", stdErr.String()) + return &Result{Path: tempOutputPath}, fmt.Errorf("failed to invoke LibreOffice: %w, raw output: %s", err, stdErr.String()) } return &Result{ Path: filepath.Join( tempOutputPath, - strings.TrimSuffix(filepath.Base(tempInputPath), filepath.Ext(tempInputPath))+"."+sofficeOpts[thumbEncodeMethod], + strings.TrimSuffix(filepath.Base(tempInputPath), filepath.Ext(tempInputPath))+"."+encode.Format, ), Continue: true, Cleanup: []func(){func() { _ = os.RemoveAll(tempOutputPath) }}, @@ -102,6 +97,6 @@ func (l *LibreOfficeGenerator) Priority() int { return 50 } -func (l *LibreOfficeGenerator) EnableFlag() string { - 
return "thumb_libreoffice_enabled" +func (l *LibreOfficeGenerator) Enabled(ctx context.Context) bool { + return l.settings.LibreOfficeThumbGeneratorEnabled(ctx) } diff --git a/pkg/thumb/music.go b/pkg/thumb/music.go new file mode 100644 index 00000000..4947336c --- /dev/null +++ b/pkg/thumb/music.go @@ -0,0 +1,79 @@ +package thumb + +import ( + "context" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/dhowden/tag" + "github.com/gofrs/uuid" + "os" + "path/filepath" +) + +func NewMusicCoverGenerator(l logging.Logger, settings setting.Provider) *MusicCoverGenerator { + return &MusicCoverGenerator{l: l, settings: settings} +} + +type MusicCoverGenerator struct { + l logging.Logger + settings setting.Provider +} + +func (v *MusicCoverGenerator) Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) { + if !util.IsInExtensionListExt(v.settings.MusicCoverThumbExts(ctx), ext) { + return nil, fmt.Errorf("unsupported music format: %w", ErrPassThrough) + } + + if es.Entity().Size() > v.settings.MusicCoverThumbMaxSize(ctx) { + return nil, fmt.Errorf("file is too big: %w", ErrPassThrough) + } + + m, err := tag.ReadFrom(es) + if err != nil { + return nil, fmt.Errorf("faield to read audio tags from file: %w", err) + } + + p := m.Picture() + if p == nil || len(p.Data) == 0 { + return nil, fmt.Errorf("no cover found in given file") + } + + thumbExt := ".jpg" + if p.Ext != "" { + thumbExt = p.Ext + } + + tempPath := filepath.Join( + util.DataPath(v.settings.TempPath(ctx)), + thumbTempFolder, + fmt.Sprintf("thumb_%s.%s", uuid.Must(uuid.NewV4()).String(), thumbExt), + ) + + thumbFile, err := util.CreatNestedFile(tempPath) + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + + defer thumbFile.Close() + + if _, err := thumbFile.Write(p.Data); err != nil { + return &Result{Path: tempPath}, fmt.Errorf("failed to write cover to file: %w", err) + } + + return &Result{ + Path: tempPath, + Continue: true, + Cleanup: []func(){func() { _ = os.Remove(tempPath) }}, + }, nil +} + +func (v *MusicCoverGenerator) Priority() int { + return 50 +} + +func (v *MusicCoverGenerator) Enabled(ctx context.Context) bool { + return v.settings.MusicCoverThumbGeneratorEnabled(ctx) +} diff --git a/pkg/thumb/pipeline.go b/pkg/thumb/pipeline.go index 8ea1fd5d..bf9bbab6 100644 --- a/pkg/thumb/pipeline.go +++ b/pkg/thumb/pipeline.go @@ -4,93 +4,114 @@ import ( "context" "errors" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "io" - "os" - "path/filepath" "reflect" "sort" - "strconv" ) -// Generator generates a thumbnail for a given reader. -type Generator interface { - // Generate generates a thumbnail for a given reader. Src is the original file path, only provided - // for local policy files. - Generate(ctx context.Context, file io.Reader, src string, name string, options map[string]string) (*Result, error) - - // Priority of execution order, smaller value means higher priority. 
- Priority() int - - // EnableFlag returns the setting name to enable this generator. - EnableFlag() string -} - -type Result struct { - Path string - Continue bool - Cleanup []func() -} - type ( + // Generator generates a thumbnail for a given reader. + Generator interface { + // Generate generates a thumbnail for a given reader. Src is the original file path, only provided + // for local policy files. State is the result from previous generators, and can be read by current + // generator for intermedia result. + Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) + + // Priority of execution order, smaller value means higher priority. + Priority() int + + // Enabled returns if current generator is enabled. + Enabled(ctx context.Context) bool + } + Result struct { + Path string + Ext string + Continue bool + Cleanup []func() + } GeneratorType string - GeneratorList []Generator + + generatorList []Generator + pipeline struct { + generators generatorList + settings setting.Provider + l logging.Logger + } ) var ( - Generators = GeneratorList{} - ErrPassThrough = errors.New("pass through") ErrNotAvailable = fmt.Errorf("thumbnail not available: %w", ErrPassThrough) ) -func (g GeneratorList) Len() int { +func (g generatorList) Len() int { return len(g) } -func (g GeneratorList) Less(i, j int) bool { +func (g generatorList) Less(i, j int) bool { return g[i].Priority() < g[j].Priority() } -func (g GeneratorList) Swap(i, j int) { +func (g generatorList) Swap(i, j int) { g[i], g[j] = g[j], g[i] } -// RegisterGenerator registers a thumbnail generator. -func RegisterGenerator(generator Generator) { - Generators = append(Generators, generator) - sort.Sort(Generators) +// NewPipeline creates a new pipeline with all available generators. 
+func NewPipeline(settings setting.Provider, l logging.Logger) Generator { + generators := generatorList{} + generators = append( + generators, + NewBuiltinGenerator(settings), + NewFfmpegGenerator(l, settings), + NewVipsGenerator(l, settings), + NewLibreOfficeGenerator(l, settings), + NewMusicCoverGenerator(l, settings), + ) + sort.Sort(generators) + + return pipeline{ + generators: generators, + settings: settings, + l: l, + } } -func (p GeneratorList) Generate(ctx context.Context, file io.Reader, src, name string, options map[string]string) (*Result, error) { - inputFile, inputSrc, inputName := file, src, name - for _, generator := range p { - if model.IsTrueVal(options[generator.EnableFlag()]) { - res, err := generator.Generate(ctx, inputFile, inputSrc, inputName, options) +func (p pipeline) Generate(ctx context.Context, es entitysource.EntitySource, ext string, state *Result) (*Result, error) { + e := es.Entity() + for _, generator := range p.generators { + if generator.Enabled(ctx) { + if _, err := es.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("thumb: failed to seek to start of file: %w", err) + } + + res, err := generator.Generate(ctx, es, ext, state) if errors.Is(err, ErrPassThrough) { - util.Log().Debug("Failed to generate thumbnail using %s for %s: %s, passing through to next generator.", reflect.TypeOf(generator).String(), name, err) + p.l.Debug("Failed to generate thumbnail using %s for %s: %s, passing through to next generator.", reflect.TypeOf(generator).String(), e.Source(), err) continue } if res != nil && res.Continue { - util.Log().Debug("Generator %s for %s returned continue, passing through to next generator.", reflect.TypeOf(generator).String(), name) + p.l.Debug("Generator %s for %s returned continue, passing through to next generator.", reflect.TypeOf(generator).String(), e.Source()) - // defer cleanup funcs + // defer cleanup functions for _, cleanup := range res.Cleanup { defer cleanup() } // prepare file reader for next generator - intermediate, err := os.Open(res.Path) + state = res + es, err = es.CloneToLocalSrc(types.EntityTypeVersion, res.Path) if err != nil { - return nil, fmt.Errorf("failed to open intermediate thumb file: %w", err) + return nil, fmt.Errorf("thumb: failed to clone to local source: %w", err) } - defer intermediate.Close() - inputFile = intermediate - inputSrc = res.Path - inputName = filepath.Base(res.Path) + defer es.Close() + ext = util.Ext(res.Path) continue } @@ -100,23 +121,10 @@ func (p GeneratorList) Generate(ctx context.Context, file io.Reader, src, name s return nil, ErrNotAvailable } -func (p GeneratorList) Priority() int { +func (p pipeline) Priority() int { return 0 } -func (p GeneratorList) EnableFlag() string { - return "" -} - -func thumbSize(options map[string]string) (uint, uint) { - w, h := uint(400), uint(300) - if wParsed, err := strconv.Atoi(options["thumb_width"]); err == nil { - w = uint(wParsed) - } - - if hParsed, err := strconv.Atoi(options["thumb_height"]); err == nil { - h = uint(hParsed) - } - - return w, h +func (p pipeline) Enabled(ctx context.Context) bool { + return true } diff --git a/pkg/thumb/tester.go b/pkg/thumb/tester.go index 1b9204f4..6439c6aa 100644 --- a/pkg/thumb/tester.go +++ b/pkg/thumb/tester.go @@ -23,13 +23,28 @@ func TestGenerator(ctx context.Context, name, executable string) (string, error) return testFfmpegGenerator(ctx, executable) case "libreOffice": return testLibreOfficeGenerator(ctx, executable) - case "libRaw": - return testLibRawGenerator(ctx, executable) + case 
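A caller-side sketch of how the assembled pipeline is meant to be driven (assumed code, not part of the patch; settings, logger and the open EntitySource come from elsewhere):

// Sketch only: driving the thumbnail pipeline from a hypothetical caller.
func makeThumb(ctx context.Context, settings setting.Provider, logger logging.Logger,
	es entitysource.EntitySource, name string) (*thumb.Result, error) {
	gen := thumb.NewPipeline(settings, logger)
	// nil previous state; generators run in Priority() order until one succeeds.
	res, err := gen.Generate(ctx, es, util.Ext(name), nil)
	if err != nil {
		// thumb.ErrNotAvailable means every enabled generator passed through.
		return nil, err
	}
	// Run res.Cleanup funcs once the file at res.Path has been persisted elsewhere.
	return res, nil
}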
"ffprobe": + return testFFProbeGenerator(ctx, executable) default: return "", ErrUnknownGenerator } } +func testFFProbeGenerator(ctx context.Context, executable string) (string, error) { + cmd := exec.CommandContext(ctx, executable, "-version") + var output bytes.Buffer + cmd.Stdout = &output + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to invoke ffmpeg executable: %w", err) + } + + if !strings.Contains(output.String(), "ffprobe") { + return "", ErrUnknownOutput + } + + return output.String(), nil +} + func testVipsGenerator(ctx context.Context, executable string) (string, error) { cmd := exec.CommandContext(ctx, executable, "--version") var output bytes.Buffer @@ -74,18 +89,3 @@ func testLibreOfficeGenerator(ctx context.Context, executable string) (string, e return output.String(), nil } - -func testLibRawGenerator(ctx context.Context, executable string) (string, error) { - cmd := exec.CommandContext(ctx, executable) - var output bytes.Buffer - cmd.Stdout = &output - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("failed to invoke libraw executable: %w", err) - } - - if !strings.Contains(output.String(), "LibRaw") { - return "", ErrUnknownOutput - } - - return output.String(), nil -} diff --git a/pkg/thumb/vips.go b/pkg/thumb/vips.go index b59c13a0..2e1001ef 100644 --- a/pkg/thumb/vips.go +++ b/pkg/thumb/vips.go @@ -5,59 +5,90 @@ import ( "context" "fmt" "io" + "os" "os/exec" "path/filepath" - "strings" + "runtime" + "strconv" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gofrs/uuid" ) -func init() { - RegisterGenerator(&VipsGenerator{}) +func NewVipsGenerator(l logging.Logger, settings setting.Provider) *VipsGenerator { + return &VipsGenerator{l: l, settings: settings} } type VipsGenerator struct { - exts []string - lastRawExts string + l logging.Logger + settings setting.Provider } -func (v *VipsGenerator) Generate(ctx context.Context, file io.Reader, src, name string, options map[string]string) (*Result, error) { - const ( - thumbVipsPath = "thumb_vips_path" - thumbVipsExts = "thumb_vips_exts" - thumbEncodeQuality = "thumb_encode_quality" - thumbEncodeMethod = "thumb_encode_method" - tempPath = "temp_path" - ) - vipsOpts := model.GetSettingByNames(thumbVipsPath, thumbVipsExts, thumbEncodeQuality, thumbEncodeMethod, tempPath) - - if v.lastRawExts != vipsOpts[thumbVipsExts] { - v.exts = strings.Split(vipsOpts[thumbVipsExts], ",") - v.lastRawExts = vipsOpts[thumbVipsExts] +func (v *VipsGenerator) Generate(ctx context.Context, es entitysource.EntitySource, ext string, previous *Result) (*Result, error) { + if !util.IsInExtensionListExt(v.settings.VipsThumbExts(ctx), ext) { + return nil, fmt.Errorf("unsupported video format: %w", ErrPassThrough) } - if !util.IsInExtensionList(v.exts, name) { - return nil, fmt.Errorf("unsupported image format: %w", ErrPassThrough) + if es.Entity().Size() > v.settings.VipsThumbMaxSize(ctx) { + return nil, fmt.Errorf("file is too big: %w", ErrPassThrough) } outputOpt := ".png" - if vipsOpts[thumbEncodeMethod] == "jpg" { - outputOpt = fmt.Sprintf(".jpg[Q=%s]", vipsOpts[thumbEncodeQuality]) + encode := v.settings.ThumbEncode(ctx) + if encode.Format == "jpg" { + outputOpt = fmt.Sprintf(".jpg[Q=%d]", encode.Quality) } + input := "[descriptor=0]" + 
usePipe := true + if runtime.GOOS == "windows" { + // Pipe IO is not working on Windows for VIPS + if es.IsLocal() { + // escape [ and ] in file name + input = fmt.Sprintf("[filename=\"%s\"]", es.LocalPath(ctx)) + usePipe = false + } else { + usePipe = false + // If not local policy files, download to temp folder + tempPath := filepath.Join( + util.DataPath(v.settings.TempPath(ctx)), + "thumb", + fmt.Sprintf("vips_%s.%s", uuid.Must(uuid.NewV4()).String(), ext), + ) + input = fmt.Sprintf("[filename=\"%s\"]", tempPath) + + // Due to limitations of ffmpeg, we need to write the input file to disk first + tempInputFile, err := util.CreatNestedFile(tempPath) + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + + defer os.Remove(tempPath) + defer tempInputFile.Close() + + if _, err = io.Copy(tempInputFile, es); err != nil { + return &Result{Path: tempPath}, fmt.Errorf("failed to write input file: %w", err) + } + + tempInputFile.Close() + } + } + + w, h := v.settings.ThumbSize(ctx) cmd := exec.CommandContext(ctx, - vipsOpts[thumbVipsPath], "thumbnail_source", "[descriptor=0]", outputOpt, options["thumb_width"], - "--height", options["thumb_height"]) + v.settings.VipsPath(ctx), "thumbnail_source", input, outputOpt, strconv.Itoa(w), + "--height", strconv.Itoa(h)) - outTempPath := filepath.Join( - util.RelativePath(vipsOpts[tempPath]), - "thumb", + tempPath := filepath.Join( + util.DataPath(v.settings.TempPath(ctx)), + thumbTempFolder, fmt.Sprintf("thumb_%s", uuid.Must(uuid.NewV4()).String()), ) - thumbFile, err := util.CreatNestedFile(outTempPath) + thumbFile, err := util.CreatNestedFile(tempPath) if err != nil { return nil, fmt.Errorf("failed to create temp file: %w", err) } @@ -66,22 +97,24 @@ func (v *VipsGenerator) Generate(ctx context.Context, file io.Reader, src, name // Redirect IO var vipsErr bytes.Buffer - cmd.Stdin = file + if usePipe { + cmd.Stdin = es + } cmd.Stdout = thumbFile cmd.Stderr = &vipsErr if err := cmd.Run(); err != nil { - util.Log().Warning("Failed to invoke vips: %s", vipsErr.String()) - return nil, fmt.Errorf("failed to invoke vips: %w", err) + v.l.Warning("Failed to invoke vips: %s", vipsErr.String()) + return &Result{Path: tempPath}, fmt.Errorf("failed to invoke vips: %w, raw output: %s", err, vipsErr.String()) } - return &Result{Path: outTempPath}, nil + return &Result{Path: tempPath}, nil } func (v *VipsGenerator) Priority() int { return 100 } -func (v *VipsGenerator) EnableFlag() string { - return "thumb_vips_enabled" +func (v *VipsGenerator) Enabled(ctx context.Context) bool { + return v.settings.VipsThumbGeneratorEnabled(ctx) } diff --git a/pkg/util/common.go b/pkg/util/common.go index fe1fa91f..a93472dd 100644 --- a/pkg/util/common.go +++ b/pkg/util/common.go @@ -1,24 +1,39 @@ package util import ( + "context" + "fmt" + "github.com/gin-gonic/gin" "math/rand" - "path/filepath" "regexp" "strings" "time" + "unicode/utf8" ) func init() { rand.Seed(time.Now().UnixNano()) } +var ( + RandomVariantAll = []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + RandomLowerCases = []rune("1234567890abcdefghijklmnopqrstuvwxyz") +) + // RandStringRunes 返回随机字符串 func RandStringRunes(n int) string { - var letterRunes = []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, n) + for i := range b { + b[i] = RandomVariantAll[rand.Intn(len(RandomVariantAll))] + } + return string(b) +} +// RandString returns random string in given length and variant +func RandString(n int, variant []rune) string 
{ b := make([]rune, n) for i := range b { - b[i] = letterRunes[rand.Intn(len(letterRunes))] + b[i] = variant[rand.Intn(len(variant))] } return string(b) } @@ -35,13 +50,27 @@ func ContainsUint(s []uint, e uint) bool { // IsInExtensionList 返回文件的扩展名是否在给定的列表范围内 func IsInExtensionList(extList []string, fileName string) bool { - ext := strings.ToLower(filepath.Ext(fileName)) + ext := Ext(fileName) + // 无扩展名时 + if len(ext) == 0 { + return false + } + + if ContainsString(extList, ext) { + return true + } + + return false +} + +// IsInExtensionList 返回文件的扩展名是否在给定的列表范围内 +func IsInExtensionListExt(extList []string, ext string) bool { // 无扩展名时 if len(ext) == 0 { return false } - if ContainsString(extList, ext[1:]) { + if ContainsString(extList, ext) { return true } @@ -122,3 +151,115 @@ func SliceDifference(slice1, slice2 []string) []string { } return nn } + +// WithValue inject key-value pair into request context. +func WithValue(c *gin.Context, key any, value any) { + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), key, value)) +} + +// BoolToString transform bool to string +func BoolToString(b bool) string { + if b { + return "1" + } + return "0" +} + +func EncodeTimeFlowString(str string, timeNow int64) string { + timeNow = timeNow / 1000 + timeDigits := []int{} + timeDigitIndex := 0 + + if len(str) == 0 { + return "" + } + + str = fmt.Sprintf("%d|%s", timeNow, str) + + res := make([]int32, 0, utf8.RuneCountInString(str)) + for timeNow > 0 { + timeDigits = append(timeDigits, int(timeNow%int64(10))) + timeNow = timeNow / 10 + } + + add := false + for pos, rune := range str { + // take single digit with index timeDigitIndex from timeNow + newIndex := pos + if add { + newIndex = pos + timeDigits[timeDigitIndex]*timeDigitIndex + } else { + newIndex = 2*timeDigitIndex*timeDigits[timeDigitIndex] - pos + } + + if newIndex < 0 { + newIndex = newIndex * -1 + } + + res = append(res, rune) + newIndex = newIndex % len(res) + + res[newIndex], res[len(res)-1] = res[len(res)-1], res[newIndex] + + add = !add + // Add timeDigitIndex by 1, but does not exceed total digits in timeNow + timeDigitIndex++ + timeDigitIndex = timeDigitIndex % len(timeDigits) + } + + return string(res) +} + +func DecodeTimeFlowStringTime(str string, timeNow int64) string { + timeNow = timeNow / 1000 + timeDigits := []int{} + + if len(str) == 0 { + return "" + } + + for timeNow > 0 { + timeDigits = append(timeDigits, int(timeNow%int64(10))) + timeNow = timeNow / 10 + } + + res := make([]int32, utf8.RuneCountInString(str)) + secret := []rune(str) + add := false + if len(secret)%2 == 0 { + add = true + } + timeDigitIndex := (len(secret) - 1) % len(timeDigits) + for pos := range secret { + // take single digit with index timeDigitIndex from timeNow + newIndex := len(res) - 1 - pos + if add { + newIndex = newIndex + timeDigits[timeDigitIndex]*timeDigitIndex + } else { + newIndex = 2*timeDigitIndex*timeDigits[timeDigitIndex] - newIndex + } + + if newIndex < 0 { + newIndex = newIndex * -1 + } + + newIndex = newIndex % len(secret) + + res[len(res)-1-pos] = secret[newIndex] + secret[newIndex], secret[len(res)-1-pos] = secret[len(res)-1-pos], secret[newIndex] + secret = secret[:len(secret)-1] + + add = !add + // Add timeDigitIndex by 1, but does not exceed total digits in timeNow + timeDigitIndex-- + if timeDigitIndex < 0 { + timeDigitIndex = len(timeDigits) - 1 + } + } + + return string(res) +} + +func ToPtr[T any](v T) *T { + return &v +} diff --git a/pkg/util/io.go b/pkg/util/io.go index fe3bd9a9..24a91742 100644 --- 
a/pkg/util/io.go +++ b/pkg/util/io.go @@ -22,7 +22,6 @@ func CreatNestedFile(path string) (*os.File, error) { if !Exists(basePath) { err := os.MkdirAll(basePath, 0700) if err != nil { - Log().Warning("Failed to create directory: %s", err) return nil, err } } @@ -30,6 +29,19 @@ func CreatNestedFile(path string) (*os.File, error) { return os.Create(path) } +// CreatNestedFolder creates a folder with the given path, if the directory does not exist, +// it will be created recursively. +func CreatNestedFolder(path string) error { + if !Exists(path) { + err := os.MkdirAll(path, 0700) + if err != nil { + return err + } + } + + return nil +} + // IsEmpty 返回给定目录是否为空目录 func IsEmpty(name string) (bool, error) { f, err := os.Open(name) @@ -44,3 +56,21 @@ func IsEmpty(name string) (bool, error) { } return false, err // Either not empty or error, suits both cases } + +type CallbackReader struct { + reader io.Reader + callback func(int64) +} + +func NewCallbackReader(reader io.Reader, callback func(int64)) *CallbackReader { + return &CallbackReader{ + reader: reader, + callback: callback, + } +} + +func (r *CallbackReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.callback(int64(n)) + return +} diff --git a/pkg/util/logger.go b/pkg/util/logger.go index 107ec718..df2d09ee 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -116,7 +116,7 @@ func (ll *Logger) Debug(format string, v ...interface{}) { // return // } // msg := fmt.Sprintf("[SQL] %s", v...) -// ll.Println(msg) +// ll.println(msg) //} // BuildLogger 构建logger diff --git a/pkg/util/path.go b/pkg/util/path.go index 2dd8aefe..257191ae 100644 --- a/pkg/util/path.go +++ b/pkg/util/path.go @@ -1,12 +1,19 @@ package util import ( + "context" "os" "path" "path/filepath" "strings" ) +const ( + DataFolder = "data" +) + +var UseWorkingDir = false + // DotPathToStandardPath 将","分割的路径转换为标准路径 func DotPathToStandardPath(path string) string { return "/" + strings.Replace(path, ",", "/", -1) @@ -50,6 +57,10 @@ func FormSlash(old string) string { // RelativePath 获取相对可执行文件的路径 func RelativePath(name string) string { + if UseWorkingDir { + return name + } + if filepath.IsAbs(name) { return name } @@ -57,3 +68,41 @@ func RelativePath(name string) string { return filepath.Join(filepath.Dir(e), name) } +// DataPath relative path for store persist data file +func DataPath(child string) string { + dataPath := RelativePath(DataFolder) + if !Exists(dataPath) { + os.MkdirAll(dataPath, 0700) + } + + if filepath.IsAbs(child) { + return child + } + + return filepath.Join(dataPath, child) +} + +// MkdirIfNotExist create directory if not exist +func MkdirIfNotExist(ctx context.Context, p string) { + if !Exists(p) { + os.MkdirAll(p, 0700) + } +} + +// SlashClean is equivalent to but slightly more efficient than +// path.Clean("/" + name). +func SlashClean(name string) string { + if name == "" || name[0] != '/' { + name = "/" + name + } + return path.Clean(name) +} + +// Ext returns the file name extension used by path, without the dot. 
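CallbackReader above simply reports the byte count of every Read to the supplied callback, which makes it a cheap progress hook. A minimal sketch, with the surrounding reader/writer variables assumed:

// Sketch: counting bytes as they stream through util.CallbackReader.
var transferred int64
src := util.NewCallbackReader(upstream, func(n int64) {
	transferred += n // invoked on every Read, including the final short/zero read at EOF
})
if _, err := io.Copy(dst, src); err != nil {
	return err
}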
+func Ext(name string) string { + ext := strings.ToLower(filepath.Ext(name)) + if len(ext) > 0 { + ext = ext[1:] + } + return ext +} diff --git a/pkg/util/session.go b/pkg/util/session.go index 705eee19..0b5825d1 100644 --- a/pkg/util/session.go +++ b/pkg/util/session.go @@ -19,13 +19,13 @@ func SetSession(c *gin.Context, list map[string]interface{}) { } // GetSession 获取session -func GetSession(c *gin.Context, key string) interface{} { +func GetSession(c *gin.Context, key any) interface{} { s := sessions.Default(c) return s.Get(key) } // DeleteSession 删除session -func DeleteSession(c *gin.Context, key string) { +func DeleteSession(c *gin.Context, key any) { s := sessions.Default(c) s.Delete(key) s.Save() diff --git a/pkg/webdav/file.go b/pkg/webdav/file.go index a0e589b6..3bbb62cb 100644 --- a/pkg/webdav/file.go +++ b/pkg/webdav/file.go @@ -5,16 +5,7 @@ package webdav import ( - "context" - "net/http" "path" - "path/filepath" - "strconv" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" ) // slashClean is equivalent to but slightly more efficient than @@ -25,182 +16,3 @@ func slashClean(name string) string { } return path.Clean(name) } - -// 更新Copy或Move后的修改时间 -func updateCopyMoveModtime(req *http.Request, fs *filesystem.FileSystem, dst string) error { - var modtime time.Time - if timeVal := req.Header.Get("X-OC-Mtime"); timeVal != "" { - timeUnix, err := strconv.ParseInt(timeVal, 10, 64) - if err == nil { - modtime = time.Unix(timeUnix, 0) - } - } - - if modtime.IsZero() { - return nil - } - - ok, fi := isPathExist(req.Context(), fs, dst) - if !ok { - return nil - } - - if fi.IsDir() { - return model.DB.Model(fi.(*model.Folder)).UpdateColumn("updated_at", modtime).Error - } - return model.DB.Model(fi.(*model.File)).UpdateColumn("updated_at", modtime).Error -} - -// moveFiles moves files and/or directories from src to dst. -// -// See section 9.9.4 for when various HTTP status codes apply. -func moveFiles(ctx context.Context, fs *filesystem.FileSystem, src FileInfo, dst string, overwrite bool) (status int, err error) { - - var ( - fileIDs []uint - folderIDs []uint - ) - if src.IsDir() { - folderIDs = []uint{src.(*model.Folder).ID} - } else { - fileIDs = []uint{src.(*model.File).ID} - } - - if overwrite { - if err := _checkOverwriteFile(ctx, fs, src, dst); err != nil { - return http.StatusInternalServerError, err - } - } - - // 判断是否需要移动 - if src.GetPosition() != path.Dir(dst) { - err = fs.Move( - context.WithValue(ctx, fsctx.WebdavDstName, path.Base(dst)), - folderIDs, - fileIDs, - src.GetPosition(), - path.Dir(dst), - ) - } else if src.GetName() != path.Base(dst) { - // 判断是否需要重命名 - err = fs.Rename( - ctx, - folderIDs, - fileIDs, - path.Base(dst), - ) - } - - if err != nil { - return http.StatusInternalServerError, err - } - return http.StatusNoContent, nil -} - -// copyFiles copies files and/or directories from src to dst. -// -// See section 9.8.5 for when various HTTP status codes apply. 
-func copyFiles(ctx context.Context, fs *filesystem.FileSystem, src FileInfo, dst string, overwrite bool, depth int, recursion int) (status int, err error) { - if recursion == 1000 { - return http.StatusInternalServerError, errRecursionTooDeep - } - recursion++ - - var ( - fileIDs []uint - folderIDs []uint - ) - - if overwrite { - if err := _checkOverwriteFile(ctx, fs, src, dst); err != nil { - return http.StatusInternalServerError, err - } - } - - if src.IsDir() { - folderIDs = []uint{src.(*model.Folder).ID} - } else { - fileIDs = []uint{src.(*model.File).ID} - } - - err = fs.Copy( - context.WithValue(ctx, fsctx.WebdavDstName, path.Base(dst)), - folderIDs, - fileIDs, - src.GetPosition(), - path.Dir(dst), - ) - if err != nil { - return http.StatusInternalServerError, err - } - - return http.StatusNoContent, nil -} - -// 判断目标 文件/夹 是否已经存在,存在则先删除目标文件/夹 -func _checkOverwriteFile(ctx context.Context, fs *filesystem.FileSystem, src FileInfo, dst string) error { - if src.IsDir() { - ok, folder := fs.IsPathExist(dst) - if ok { - return fs.Delete(ctx, []uint{folder.ID}, []uint{}, false, false) - } - } else { - ok, file := fs.IsFileExist(dst) - if ok { - return fs.Delete(ctx, []uint{}, []uint{file.ID}, false, false) - } - } - return nil -} - -// walkFS traverses filesystem fs starting at name up to depth levels. -// -// Allowed values for depth are 0, 1 or infiniteDepth. For each visited node, -// walkFS calls walkFn. If a visited file system node is a directory and -// walkFn returns filepath.SkipDir, walkFS will skip traversal of this node. -func walkFS( - ctx context.Context, - fs *filesystem.FileSystem, - depth int, - name string, - info FileInfo, - walkFn func(reqPath string, info FileInfo, err error) error) error { - // This implementation is based on Walk's code in the standard path/filepath package. 
- err := walkFn(name, info, nil) - if err != nil { - if info.IsDir() && err == filepath.SkipDir { - return nil - } - return err - } - if !info.IsDir() || depth == 0 { - return nil - } - if depth == 1 { - depth = 0 - } - - dirs, _ := info.(*model.Folder).GetChildFolder() - files, _ := info.(*model.Folder).GetChildFiles() - - for _, fileInfo := range files { - filename := path.Join(name, fileInfo.Name) - err = walkFS(ctx, fs, depth, filename, &fileInfo, walkFn) - if err != nil { - if !fileInfo.IsDir() || err != filepath.SkipDir { - return err - } - } - } - - for _, fileInfo := range dirs { - filename := path.Join(name, fileInfo.Name) - err = walkFS(ctx, fs, depth, filename, &fileInfo, walkFn) - if err != nil { - if !fileInfo.IsDir() || err != filepath.SkipDir { - return err - } - } - } - return nil -} diff --git a/pkg/webdav/prop.go b/pkg/webdav/prop.go index fa5e76e7..e7ea8728 100644 --- a/pkg/webdav/prop.go +++ b/pkg/webdav/prop.go @@ -7,83 +7,168 @@ package webdav import ( "bytes" "context" + "encoding/json" "encoding/xml" "errors" "fmt" - "mime" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/gin-gonic/gin" "net/http" - "path/filepath" "strconv" + "strings" "time" +) - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" +//// 实现 webdav.DeadPropsHolder 接口,不能在models.file里面定义 +//func (file *FileDeadProps) DeadProps() (map[xml.Name]Property, error) { +// return map[xml.Name]Property{ +// xml.Name{Space: "http://owncloud.org/ns", Local: "checksums"}: { +// XMLName: xml.Name{ +// Space: "http://owncloud.org/ns", Local: "checksums", +// }, +// InnerXML: []byte("" + file.MetadataSerialized[model.ChecksumMetadataKey] + ""), +// }, +// }, nil +//} +// +//func (file *FileDeadProps) Patch(proppatches []Proppatch) ([]Propstat, error) { +// var ( +// stat Propstat +// err error +// ) +// stat.Status = http.StatusOK +// for _, patch := range proppatches { +// for _, prop := range patch.Props { +// stat.Props = append(stat.Props, Property{XMLName: prop.XMLName}) +// if prop.XMLName.Space == "DAV:" && prop.XMLName.Local == "lastmodified" { +// var modtimeUnix int64 +// modtimeUnix, err = strconv.ParseInt(string(prop.InnerXML), 10, 64) +// if err == nil { +// err = model.DB.Model(file).UpdateColumn("updated_at", time.Unix(modtimeUnix, 0)).Error +// } +// } +// } +// } +// return []Propstat{stat}, err +//} +// +//type FolderDeadProps struct { +// *model.Folder +//} +// +//func (folder *FolderDeadProps) DeadProps() (map[xml.Name]Property, error) { +// return nil, nil +//} +// +//func (folder *FolderDeadProps) Patch(proppatches []Proppatch) ([]Propstat, error) { +// var ( +// stat Propstat +// err error +// ) +// stat.Status = http.StatusOK +// for _, patch := range proppatches { +// for _, prop := range patch.Props { +// stat.Props = append(stat.Props, Property{XMLName: prop.XMLName}) +// if prop.XMLName.Space == "DAV:" && prop.XMLName.Local == "lastmodified" { +// var modtimeUnix int64 +// modtimeUnix, err = strconv.ParseInt(string(prop.InnerXML), 10, 64) +// if err == nil { +// err = model.DB.Model(folder).UpdateColumn("updated_at", time.Unix(modtimeUnix, 0)).Error +// } +// } +// } +// } +// return []Propstat{stat}, err +//} + +const ( + DeadPropsMetadataPrefix = 
"dav:" + SpaceNameSeparator = "|" ) -type FileDeadProps struct { - *model.File -} +type ( + // DeadPropsStore implements DeadPropsHolder interface with metadata based store. + metadataDeadProps struct { + f fs.File + fm manager.FileManager + } -// 实现 webdav.DeadPropsHolder 接口,不能在models.file里面定义 -func (file *FileDeadProps) DeadProps() (map[xml.Name]Property, error) { - return map[xml.Name]Property{ - xml.Name{Space: "http://owncloud.org/ns", Local: "checksums"}: { - XMLName: xml.Name{ - Space: "http://owncloud.org/ns", Local: "checksums", - }, - InnerXML: []byte("" + file.MetadataSerialized[model.ChecksumMetadataKey] + ""), - }, - }, nil -} + DeadPropsStore struct { + Lang string `json:"l,omitempty"` + InnerXML []byte `json:"i,omitempty"` + } +) -func (file *FileDeadProps) Patch(proppatches []Proppatch) ([]Propstat, error) { - var ( - stat Propstat - err error - ) - stat.Status = http.StatusOK - for _, patch := range proppatches { - for _, prop := range patch.Props { - stat.Props = append(stat.Props, Property{XMLName: prop.XMLName}) - if prop.XMLName.Space == "DAV:" && prop.XMLName.Local == "lastmodified" { - var modtimeUnix int64 - modtimeUnix, err = strconv.ParseInt(string(prop.InnerXML), 10, 64) - if err == nil { - err = model.DB.Model(file.File).UpdateColumn("updated_at", time.Unix(modtimeUnix, 0)).Error - } - } +func (m *metadataDeadProps) DeadProps() (map[xml.Name]Property, error) { + meta := m.f.Metadata() + res := make(map[xml.Name]Property) + for k, v := range meta { + if !strings.HasPrefix(k, DeadPropsMetadataPrefix) { + continue } - } - return []Propstat{stat}, err -} -type FolderDeadProps struct { - *model.Folder -} + spaceLocal := strings.SplitN(strings.TrimPrefix(k, DeadPropsMetadataPrefix), SpaceNameSeparator, 2) + name := xml.Name{spaceLocal[0], spaceLocal[1]} + propsStore := &DeadPropsStore{} + if err := json.Unmarshal([]byte(v), propsStore); err != nil { + return nil, err + } + + res[name] = Property{ + XMLName: name, + InnerXML: propsStore.InnerXML, + Lang: propsStore.Lang, + } + } -func (folder *FolderDeadProps) DeadProps() (map[xml.Name]Property, error) { - return nil, nil + return res, nil } -func (folder *FolderDeadProps) Patch(proppatches []Proppatch) ([]Propstat, error) { - var ( - stat Propstat - err error - ) - stat.Status = http.StatusOK +func (m *metadataDeadProps) Patch(ctx context.Context, proppatches []Proppatch) ([]Propstat, error) { + metadataArgs := make([]fs.MetadataPatch, 0, len(proppatches)) + pstat := Propstat{Status: http.StatusOK} for _, patch := range proppatches { + translateFn := func(p Property) (*fs.MetadataPatch, error) { + val, err := json.Marshal(&DeadPropsStore{ + Lang: p.Lang, + InnerXML: p.InnerXML, + }) + if err != nil { + return nil, err + } + return &fs.MetadataPatch{ + Key: DeadPropsMetadataPrefix + p.XMLName.Space + SpaceNameSeparator + p.XMLName.Local, + Value: string(val), + }, nil + } + if patch.Remove { + translateFn = func(p Property) (*fs.MetadataPatch, error) { + return &fs.MetadataPatch{ + Key: DeadPropsMetadataPrefix + p.XMLName.Space + SpaceNameSeparator + p.XMLName.Local, + Remove: true, + }, nil + } + } for _, prop := range patch.Props { - stat.Props = append(stat.Props, Property{XMLName: prop.XMLName}) - if prop.XMLName.Space == "DAV:" && prop.XMLName.Local == "lastmodified" { - var modtimeUnix int64 - modtimeUnix, err = strconv.ParseInt(string(prop.InnerXML), 10, 64) - if err == nil { - err = model.DB.Model(folder.Folder).UpdateColumn("updated_at", time.Unix(modtimeUnix, 0)).Error - } + pstat.Props = append(pstat.Props, 
Property{XMLName: prop.XMLName}) + patch, err := translateFn(prop) + if err != nil { + return nil, err } + metadataArgs = append(metadataArgs, *patch) } } - return []Propstat{stat}, err + + if err := m.fm.PatchMedata(ctx, []*fs.URI{m.f.Uri(false)}, metadataArgs...); err != nil { + return nil, err + } + + return []Propstat{pstat}, nil } type FileInfo interface { @@ -172,14 +257,14 @@ type DeadPropsHolder interface { // // For more details on when various HTTP status codes apply, see // http://www.webdav.org/specs/rfc4918.html#PROPPATCH-status - Patch([]Proppatch) ([]Propstat, error) + Patch(context.Context, []Proppatch) ([]Propstat, error) } // liveProps contains all supported, protected DAV: properties. var liveProps = map[xml.Name]struct { // findFn implements the propfind function of this property. If nil, // it indicates a hidden property. - findFn func(context.Context, *filesystem.FileSystem, LockSystem, string, FileInfo) (string, error) + findFn func(context.Context, manager.FileManager, fs.File) (string, error) // dir is true if the property applies to directories. dir bool }{ @@ -207,8 +292,8 @@ var liveProps = map[xml.Name]struct { dir: true, }, {Space: "DAV:", Local: "creationdate"}: { - findFn: nil, - dir: false, + findFn: findCreationDate, + dir: true, }, {Space: "DAV:", Local: "getcontentlanguage"}: { findFn: nil, @@ -234,6 +319,14 @@ var liveProps = map[xml.Name]struct { findFn: findSupportedLock, dir: true, }, + {Space: "DAV:", Local: "quota-used-bytes"}: { + findFn: findQuotaUsedBytes, + dir: true, + }, + {Space: "DAV:", Local: "quota-available-bytes"}: { + findFn: findQuotaAvailableBytes, + dir: true, + }, } // TODO(nigeltao) merge props and allprop? @@ -242,19 +335,20 @@ var liveProps = map[xml.Name]struct { // // Each Propstat has a unique status and each property name will only be part // of one Propstat element. -func props(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, fi FileInfo, pnames []xml.Name) ([]Propstat, error) { - isDir := fi.IsDir() - if !isDir { - fi = &FileDeadProps{fi.(*model.File)} +func props(c *gin.Context, file fs.File, fm manager.FileManager, pnames []xml.Name) ([]Propstat, error) { + isDir := file.Type() == types.FileTypeFolder + dph := &metadataDeadProps{ + f: file, + fm: fm, } - var deadProps map[xml.Name]Property - if dph, ok := fi.(DeadPropsHolder); ok { - var err error - deadProps, err = dph.DeadProps() - if err != nil { - return nil, err - } + var ( + deadProps map[xml.Name]Property + err error + ) + deadProps, err = dph.DeadProps() + if err != nil { + return nil, err } pstatOK := Propstat{Status: http.StatusOK} @@ -267,8 +361,14 @@ func props(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, fi Fil } // Otherwise, it must either be a live property or we don't know it. if prop := liveProps[pn]; prop.findFn != nil && (prop.dir || !isDir) { - innerXML, err := prop.findFn(ctx, fs, ls, fi.GetName(), fi) + innerXML, err := prop.findFn(c, fm, file) if err != nil { + if errors.Is(err, ErrNotImplemented) { + pstatNotFound.Props = append(pstatNotFound.Props, Property{ + XMLName: pn, + }) + continue + } return nil, err } pstatOK.Props = append(pstatOK.Props, Property{ @@ -285,21 +385,18 @@ func props(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, fi Fil } // Propnames returns the property names defined for resource name. 
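To make the metadata-backed dead-property store concrete: each property is keyed by the "dav:" prefix plus "<namespace>|<local-name>", and the value is the JSON-encoded DeadPropsStore. A worked example with a hypothetical ownCloud-style property:

// Illustrative mapping, not part of the patch.
prop := Property{
	XMLName:  xml.Name{Space: "http://owncloud.org/ns", Local: "favorite"}, // hypothetical property
	InnerXML: []byte("1"),
}
key := DeadPropsMetadataPrefix + prop.XMLName.Space + SpaceNameSeparator + prop.XMLName.Local
// key == "dav:http://owncloud.org/ns|favorite"
val, _ := json.Marshal(&DeadPropsStore{InnerXML: prop.InnerXML})
// val == `{"i":"MQ=="}` — InnerXML is a []byte, so encoding/json base64-encodes it.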
-func propnames(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, fi FileInfo) ([]xml.Name, error) { - isDir := fi.IsDir() - if !isDir { - fi = &FileDeadProps{fi.(*model.File)} - } - +func propnames(c *gin.Context, file fs.File, fm manager.FileManager) ([]xml.Name, error) { var deadProps map[xml.Name]Property - if dph, ok := fi.(DeadPropsHolder); ok { - var err error - deadProps, err = dph.DeadProps() - if err != nil { - return nil, err - } + dph := &metadataDeadProps{ + f: file, + fm: fm, + } + deadProps, err := dph.DeadProps() + if err != nil { + return nil, err } + isDir := file.Type() == types.FileTypeFolder pnames := make([]xml.Name, 0, len(liveProps)+len(deadProps)) for pn, prop := range liveProps { if prop.findFn != nil && (prop.dir || !isDir) { @@ -320,8 +417,8 @@ func propnames(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, fi // returned if they are named in 'include'. // // See http://www.webdav.org/specs/rfc4918.html#METHOD_PROPFIND -func allprop(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, info FileInfo, include []xml.Name) ([]Propstat, error) { - pnames, err := propnames(ctx, fs, ls, info) +func allprop(c *gin.Context, file fs.File, fm manager.FileManager, include []xml.Name) ([]Propstat, error) { + pnames, err := propnames(c, file, fm) if err != nil { return nil, err } @@ -335,12 +432,12 @@ func allprop(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, info pnames = append(pnames, pn) } } - return props(ctx, fs, ls, info, pnames) + return props(c, file, fm, pnames) } // Patch patches the properties of resource name. The return values are // constrained in the same manner as DeadPropsHolder.Patch. -func patch(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, patches []Proppatch) ([]Propstat, error) { +func patch(c context.Context, file fs.File, fm manager.FileManager, patches []Proppatch) ([]Propstat, error) { conflict := false loop: for _, patch := range patches { @@ -372,37 +469,24 @@ loop: } // very unlikely to be false - exist, info := isPathExist(ctx, fs, name) - if exist { - var dph DeadPropsHolder - if info.IsDir() { - dph = &FolderDeadProps{info.(*model.Folder)} - } else { - dph = &FileDeadProps{info.(*model.File)} - } - ret, err := dph.Patch(patches) - if err != nil { - return nil, err - } - // http://www.webdav.org/specs/rfc4918.html#ELEMENT_propstat says that - // "The contents of the prop XML element must only list the names of - // properties to which the result in the status element applies." - for _, pstat := range ret { - for i, p := range pstat.Props { - pstat.Props[i] = Property{XMLName: p.XMLName} - } - } - return ret, nil + dph := &metadataDeadProps{ + f: file, + fm: fm, } - // The file doesn't implement the optional DeadPropsHolder interface, so - // all patches are forbidden. - pstat := Propstat{Status: http.StatusOK} - for _, patch := range patches { - for _, p := range patch.Props { - pstat.Props = append(pstat.Props, Property{XMLName: p.XMLName}) + + ret, err := dph.Patch(c, patches) + if err != nil { + return nil, err + } + // http://www.webdav.org/specs/rfc4918.html#ELEMENT_propstat says that + // "The contents of the prop XML element must only list the names of + // properties to which the result in the status element applies." 
+ for _, pstat := range ret { + for i, p := range pstat.Props { + pstat.Props[i] = Property{XMLName: p.XMLName} } } - return []Propstat{pstat}, nil + return ret, nil } func escapeXML(s string) string { @@ -425,103 +509,68 @@ func escapeXML(s string) string { return s } -func findResourceType(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { - if fi.IsDir() { +// ErrNotImplemented should be returned by optional interfaces if they +// want the original implementation to be used. +var ErrNotImplemented = errors.New("not implemented") + +func findResourceType(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + if file.Type() == types.FileTypeFolder { return ``, nil } return "", nil } -func findDisplayName(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { - if slashClean(name) == "/" { - // Hide the real name of a possibly prefixed root directory. - return "", nil - } - return escapeXML(fi.GetName()), nil +func findDisplayName(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + return escapeXML(file.DisplayName()), nil } -func findContentLength(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { - return strconv.FormatUint(fi.GetSize(), 10), nil +func findContentLength(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + return strconv.FormatInt(file.Size(), 10), nil } -func findLastModified(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { - return fi.ModTime().UTC().Format(http.TimeFormat), nil +func findLastModified(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + return file.UpdatedAt().UTC().Format(http.TimeFormat), nil } -// ErrNotImplemented should be returned by optional interfaces if they -// want the original implementation to be used. -var ErrNotImplemented = errors.New("not implemented") +func findCreationDate(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + return file.CreatedAt().UTC().Format(http.TimeFormat), nil +} -// ContentTyper is an optional interface for the os.FileInfo -// objects returned by the FileSystem. -// -// If this interface is defined then it will be used to read the -// content type from the object. -// -// If this interface is not defined the file will be opened and the -// content type will be guessed from the initial contents of the file. -type ContentTyper interface { - // ContentType returns the content type for the file. - // - // If this returns error ErrNotImplemented then the error will - // be ignored and the base implementation will be used - // instead. - ContentType(ctx context.Context) (string, error) +func findContentType(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + d := dependency.FromContext(ctx) + return d.MimeDetector(ctx).TypeByName(file.DisplayName()), nil } -func findContentType(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { - //if do, ok := fi.(ContentTyper); ok { - // ctype, err := do.ContentType(ctx) - // if err != ErrNotImplemented { - // return ctype, err - // } - //} - //f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0) - //if err != nil { - // return "", err - //} - //defer f.Close() - //// This implementation is based on serveContent's code in the standard net/http package. 
- //ctype := mime.TypeByExtension(filepath.Ext(name)) - //if ctype != "" { - // return ctype, nil - //} - //// Read a chunk to decide between utf-8 text and binary. - //var buf [512]byte - //n, err := io.ReadFull(f, buf[:]) - //if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { - // return "", err - //} - //ctype = http.DetectContentType(buf[:n]) - //// Rewind file. - //_, err = f.Seek(0, os.SEEK_SET) - //return ctype, err - return mime.TypeByExtension(filepath.Ext(name)), nil +func findETag(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + hasher := dependency.FromContext(ctx).HashIDEncoder() + return fmt.Sprintf(`"%s"`, hashid.EncodeEntityID(hasher, file.PrimaryEntityID())), nil } -// ETager is an optional interface for the os.FileInfo objects -// returned by the FileSystem. -// -// If this interface is defined then it will be used to read the ETag -// for the object. -// -// If this interface is not defined an ETag will be computed using the -// ModTime() and the Size() methods of the os.FileInfo object. -type ETager interface { - // ETag returns an ETag for the file. This should be of the - // form "value" or W/"value" - // - // If this returns error ErrNotImplemented then the error will - // be ignored and the base implementation will be used - // instead. - ETag(ctx context.Context) (string, error) +func findQuotaUsedBytes(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + requester := inventory.UserFromContext(ctx) + if file.Owner().ID != requester.ID { + return "", ErrNotImplemented + } + capacity, err := fm.Capacity(ctx) + if err != nil { + return "", err + } + return strconv.FormatInt(capacity.Used, 10), nil } -func findETag(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, reqPath string, fi FileInfo) (string, error) { - return fmt.Sprintf(`"%x%x"`, fi.ModTime().UnixNano(), fi.GetSize()), nil +func findQuotaAvailableBytes(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { + requester := inventory.UserFromContext(ctx) + if file.Owner().ID != requester.ID { + return "", ErrNotImplemented + } + capacity, err := fm.Capacity(ctx) + if err != nil { + return "", err + } + return strconv.FormatInt(capacity.Total-capacity.Used, 10), nil } -func findSupportedLock(ctx context.Context, fs *filesystem.FileSystem, ls LockSystem, name string, fi FileInfo) (string, error) { +func findSupportedLock(ctx context.Context, fm manager.FileManager, file fs.File) (string, error) { return `` + `` + `` + diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 5ab99068..da20873b 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -9,468 +9,573 @@ import ( "context" "errors" "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-gonic/gin" + 
"github.com/samber/lo" + "golang.org/x/tools/container/intsets" "net/http" - "net/http/httputil" "net/url" "path" - "strconv" "strings" - "sync" "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/util" ) -type Handler struct { - // Prefix is the URL path prefix to strip from WebDAV resource paths. - Prefix string - // LockSystem is the lock management system. - LockSystem map[uint]LockSystem - // Logger is an optional error logger. If non-nil, it will be called - // for all HTTP requests. - Logger func(*http.Request, error) - Mutex *sync.Mutex -} +const ( + davPrefix = "/dav" +) -func (h *Handler) stripPrefix(p string, uid uint) (string, int, error) { - if h.Prefix == "" { - return p, http.StatusOK, nil +func stripPrefix(p string, u *ent.User) (string, *fs.URI, int, error) { + base, err := fs.NewUriFromString(u.Edges.DavAccounts[0].URI) + if err != nil { + return "", nil, http.StatusInternalServerError, err } - prefix := h.Prefix + + prefix := davPrefix if r := strings.TrimPrefix(p, prefix); len(r) < len(p) { - if len(r) == 0 { - r = "/" - } - return util.RemoveSlash(r), http.StatusOK, nil + r = strings.TrimPrefix(r, fs.Separator) + return r, base.JoinRaw(util.RemoveSlash(r)), http.StatusOK, nil } - return p, http.StatusNotFound, errPrefixMismatch + return "", nil, http.StatusNotFound, errPrefixMismatch } -// isPathExist 路径是否存在 -func isPathExist(ctx context.Context, fs *filesystem.FileSystem, path string) (bool, FileInfo) { - // 尝试目录 - if ok, folder := fs.IsPathExist(path); ok { - return ok, folder +func ServeHTTP(c *gin.Context) { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + fm := manager.NewFileManager(dep, u) + defer fm.Recycle() + + status, err := http.StatusBadRequest, errUnsupportedMethod + + switch c.Request.Method { + case "OPTIONS": + status, err = handleOptions(c, u, fm) + case "GET", "HEAD", "POST": + status, err = handleGetHeadPost(c, u, fm) + case "DELETE": + status, err = handleDelete(c, u, fm) + case "PUT": + status, err = handlePut(c, u, fm) + case "MKCOL": + status, err = handleMkcol(c, u, fm) + case "COPY", "MOVE": + status, err = handleCopyMove(c, u, fm) + case "LOCK": + status, err = handleLock(c, u, fm) + case "UNLOCK": + status, err = handleUnlock(c, u, fm) + case "PROPFIND": + status, err = handlePropfind(c, u, fm) + case "PROPPATCH": + status, err = handleProppatch(c, u, fm) } - if ok, file := fs.IsFileExist(path); ok { - return ok, file + if status != 0 { + c.Writer.WriteHeader(status) + if status != http.StatusNoContent { + c.Writer.Write([]byte(StatusText(status))) + } + } + + if err != nil { + dep.Logger().Debug("WebDAV request failed with error: %s", err) } - return false, nil } -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) { - status, err := http.StatusBadRequest, errUnsupportedMethod - h.Mutex.Lock() - if h.LockSystem == nil { - h.Mutex.Unlock() - status, err = http.StatusInternalServerError, errNoLockSystem - } else { - // 检查并新建 LockSystem - ls, ok := h.LockSystem[fs.User.ID] - if !ok { - h.LockSystem[fs.User.ID] = NewMemLS() - ls = h.LockSystem[fs.User.ID] +func confirmLock(c *gin.Context, fm manager.FileManager, user *ent.User, srcAnc, dstAnc fs.File, src, dst *fs.URI) (func(), fs.LockSession, int, error) { + hdr := c.Request.Header.Get("If") + if hdr == "" { + // An empty If header means that the client hasn't 
previously created locks. + // Even if this client doesn't care about locks, we still need to check that + // the resources aren't locked by another client, so we create temporary + // locks that would conflict with another client's locks. These temporary + // locks are unlocked at the end of the HTTP request. + srcToken, dstToken := "", "" + ap := fs.LockApp(fs.ApplicationDAV) + var ( + ctx context.Context = c + ls fs.LockSession + err error + ) + if src != nil { + ls, err = fm.Lock(ctx, -1, user, true, ap, src, "") + if err != nil { + return nil, nil, purposeStatusCodeFromError(err), err + } + srcToken = ls.LastToken() + ctx = fs.LockSessionToContext(ctx, ls) } - h.Mutex.Unlock() - - switch r.Method { - case "OPTIONS": - status, err = h.handleOptions(w, r, fs) - case "GET", "HEAD", "POST": - status, err = h.handleGetHeadPost(w, r, fs) - case "DELETE": - status, err = h.handleDelete(w, r, fs) - case "PUT": - status, err = h.handlePut(w, r, fs) - case "MKCOL": - status, err = h.handleMkcol(w, r, fs) - case "COPY", "MOVE": - status, err = h.handleCopyMove(w, r, fs) - case "LOCK": - status, err = h.handleLock(w, r, fs, ls) - case "UNLOCK": - status, err = h.handleUnlock(w, r, fs, ls) - case "PROPFIND": - status, err = h.handlePropfind(w, r, fs, ls) - case "PROPPATCH": - status, err = h.handleProppatch(w, r, fs, ls) + + if dst != nil { + ls, err = fm.Lock(ctx, -1, user, true, ap, dst, "") + if err != nil { + if src != nil { + _ = fm.Unlock(ctx, srcToken) + } + return nil, nil, purposeStatusCodeFromError(err), err + } + dstToken = ls.LastToken() + ctx = fs.LockSessionToContext(ctx, ls) } + + return func() { + if dstToken != "" { + _ = fm.Unlock(ctx, dstToken) + } + if srcToken != "" { + _ = fm.Unlock(ctx, srcToken) + } + }, ls, 0, nil } - if status != 0 { - w.WriteHeader(status) - if status != http.StatusNoContent { - w.Write([]byte(StatusText(status))) + ih, ok := parseIfHeader(hdr) + if !ok { + return nil, nil, http.StatusBadRequest, errInvalidIfHeader + } + // ih is a disjunction (OR) of ifLists, so any ifList will do. + for _, l := range ih.lists { + var ( + releaseSrc = func() {} + releaseDst = func() {} + ls fs.LockSession + err error + ) + if src != nil { + releaseSrc, ls, err = fm.ConfirmLock(c, srcAnc, src, lo.Map(l.conditions, func(c Condition, index int) string { + return c.Token + })...) + if errors.Is(err, lock.ErrConfirmationFailed) { + continue + } + if err != nil { + return nil, nil, purposeStatusCodeFromError(err), err + } } + + if dst != nil { + releaseDst, ls, err = fm.ConfirmLock(c, dstAnc, dst, lo.Map(l.conditions, func(c Condition, index int) string { + return c.Token + })...) + if errors.Is(err, lock.ErrConfirmationFailed) { + continue + } + if err != nil { + return nil, nil, purposeStatusCodeFromError(err), err + } + } + + return func() { + releaseDst() + releaseSrc() + }, ls, 0, nil } - if h.Logger != nil { - h.Logger(r, err) - } + // Section 10.4.1 says that "If this header is evaluated and all state lists + // fail, then the request must fail with a 412 (Precondition Failed) status." + // We follow the spec even though the cond_put_corrupt_token test case from + // the litmus test warns on seeing a 412 instead of a 423 (Locked). 
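For reference, the header parsed here is the RFC 4918 If header; each parenthesised list is one conjunction, and only the lock tokens inside it are forwarded to ConfirmLock. A sketch of the extraction, where the token URN is a made-up example:

// Example request header:
//   If: (<urn:uuid:150852e2-3847-42d5-8cbe-0f4f296f26cf>)
ih, ok := parseIfHeader(c.Request.Header.Get("If"))
if ok {
	for _, l := range ih.lists {
		tokens := lo.Map(l.conditions, func(cond Condition, _ int) string { return cond.Token })
		_ = tokens // handed to fm.ConfirmLock(c, ancestor, uri, tokens...)
	}
}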
+ return nil, nil, http.StatusPreconditionFailed, ErrLocked } -// OK -func (h *Handler) lock(now time.Time, root string, fs *filesystem.FileSystem, ls LockSystem) (token string, status int, err error) { - //token, err = ls.Create(now, LockDetails{ - // Root: root, - // Duration: infiniteTimeout, - // ZeroDepth: true, - //}) - //if err != nil { - // if err == ErrLocked { - // return "", StatusLocked, err - // } - // return "", http.StatusInternalServerError, err - //} - - return fmt.Sprintf("%d", time.Now().Unix()), 0, nil -} +func handleMkcol(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) + if err != nil { + return status, err + } -// ok -func (h *Handler) confirmLocks(r *http.Request, src, dst string, fs *filesystem.FileSystem) (release func(), status int, err error) { - - //hdr := r.Header.Get("If") - //h.Mutex.Lock() - //ls,ok := h.LockSystem[fs.User.ID] - //h.Mutex.Unlock() - //if !ok{ - // return nil, http.StatusInternalServerError, errNoLockSystem - //} - // - //if hdr == "" { - // // An empty If header means that the client hasn't previously created locks. - // // Even if this client doesn't care about locks, we still need to check that - // // the resources aren't locked by another client, so we create temporary - // // locks that would conflict with another client's locks. These temporary - // // locks are unlocked at the end of the HTTP request. - // now, srcToken, dstToken := time.Now(), "", "" - // if src != "" { - // srcToken, status, err = h.lock(now, src, fs,ls) - // if err != nil { - // return nil, status, err - // } - // } - // if dst != "" { - // dstToken, status, err = h.lock(now, dst, fs,ls) - // if err != nil { - // if srcToken != "" { - // ls.Unlock(now, srcToken) - // } - // return nil, status, err - // } - // } - // - // return func() { - // if dstToken != "" { - // ls.Unlock(now, dstToken) - // } - // if srcToken != "" { - // ls.Unlock(now, srcToken) - // } - // }, 0, nil - //} - // - //ih, ok := parseIfHeader(hdr) - //if !ok { - // return nil, http.StatusBadRequest, errInvalidIfHeader - //} - //// ih is a disjunction (OR) of ifLists, so any ifList will do. - //for _, l := range ih.lists { - // lsrc := l.resourceTag - // if lsrc == "" { - // lsrc = src - // } else { - // u, err := url.Parse(lsrc) - // if err != nil { - // continue - // } - // //if u.Host != r.Host { - // // continue - // //} - // lsrc, status, err = h.stripPrefix(u.Path, fs.User.ID) - // if err != nil { - // return nil, status, err - // } - // } - // release, err = ls.Confirm( - // time.Now(), - // lsrc, - // dst, - // l.conditions..., - // ) - // if err == ErrConfirmationFailed { - // continue - // } - // if err != nil { - // return nil, http.StatusInternalServerError, err - // } - // return release, 0, nil - //} - //// Section 10.4.1 says that "If this header is evaluated and all state lists - //// fail, then the request must fail with a 412 (Precondition Failed) status." - //// We follow the spec even though the cond_put_corrupt_token test case from - //// the litmus test warns on seeing a 412 instead of a 423 (Locked). 
- //return nil, http.StatusPreconditionFailed, ErrLocked - - return func() { - - }, 0, nil -} + ancestor, uri, err := fm.SharedAddressTranslation(c, reqPath) + if err != nil && !ent.IsNotFound(err) { + return purposeStatusCodeFromError(err), err + } -// OK -func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) + release, ls, status, err := confirmLock(c, fm, user, ancestor, nil, uri, nil) if err != nil { return status, err } - ctx := r.Context() - allow := "OPTIONS, LOCK, PUT, MKCOL" - if exist, fi := isPathExist(ctx, fs, reqPath); exist { - if fi.IsDir() { - allow = "OPTIONS, LOCK, DELETE, PROPPATCH, COPY, MOVE, UNLOCK, PROPFIND" - } else { - allow = "OPTIONS, LOCK, GET, HEAD, POST, DELETE, PROPPATCH, COPY, MOVE, UNLOCK, PROPFIND, PUT" - } + defer release() + ctx := fs.LockSessionToContext(c, ls) + + if c.Request.ContentLength > 0 { + return http.StatusUnsupportedMediaType, nil } - w.Header().Set("Allow", allow) - // http://www.webdav.org/specs/rfc4918.html#dav.compliance.classes - w.Header().Set("DAV", "1, 2") - // http://msdn.microsoft.com/en-au/library/cc250217.aspx - w.Header().Set("MS-Author-Via", "DAV") - return 0, nil -} -var proxy = &httputil.ReverseProxy{ - Director: func(request *http.Request) { - if target, ok := request.Context().Value(fsctx.WebDAVProxyUrlCtx).(*url.URL); ok { - request.URL.Scheme = target.Scheme - request.URL.Host = target.Host - request.URL.Path = target.Path - request.URL.RawPath = target.RawPath - request.URL.RawQuery = target.RawQuery - request.Host = target.Host - request.Header.Del("Authorization") - } - }, - ErrorHandler: func(writer http.ResponseWriter, request *http.Request, err error) { - writer.WriteHeader(http.StatusInternalServerError) - }, + _, err = fm.Create(ctx, uri, types.FileTypeFolder, dbfs.WithNoChainedCreation(), dbfs.WithErrorOnConflict()) + if err != nil { + return purposeStatusCodeFromError(err), err + } + + return http.StatusCreated, nil } -// OK -func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - defer fs.Recycle() +func handlePut(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) + if err != nil { + return status, err + } + + ancestor, uri, err := fm.SharedAddressTranslation(c, reqPath) + if err != nil && !ent.IsNotFound(err) { + return purposeStatusCodeFromError(err), err + } - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) + release, ls, status, err := confirmLock(c, fm, user, ancestor, nil, uri, nil) if err != nil { return status, err } + defer release() - ctx := r.Context() + ctx := fs.LockSessionToContext(c, ls) + // TODO(rost): Support the If-Match, If-None-Match headers? See bradfitz' + // comments in http.checkEtag. 
- exist, file := fs.IsFileExist(reqPath) - if !exist { - return http.StatusNotFound, nil + rc, fileSize, err := request.SniffContentLength(c.Request) + if err != nil { + return http.StatusBadRequest, err } - fs.SetTargetFile(&[]model.File{*file}) - rs, err := fs.Preview(ctx, 0, false) + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: uri, + //MimeType: c.Request.Header.Get("Content-Type"), + Size: fileSize, + }, + File: rc, + Mode: fs.ModeOverwrite, + } + + m := manager.NewFileManager(dependency.FromContext(ctx), user) + defer m.Recycle() + + // Update file + res, err := m.Update(ctx, fileData) if err != nil { - if err == filesystem.ErrObjectNotExist { - return http.StatusNotFound, err - } - return http.StatusInternalServerError, err + return purposeStatusCodeFromError(err), err } - etag, err := findETag(ctx, fs, nil, reqPath, &fs.FileTarget[0]) + etag, err := findETag(ctx, fm, res) if err != nil { return http.StatusInternalServerError, err } - w.Header().Set("ETag", etag) - if !rs.Redirect { - defer rs.Content.Close() - // 获取文件内容 - http.ServeContent(w, r, reqPath, fs.FileTarget[0].UpdatedAt, rs.Content) - return 0, nil - } + c.Writer.Header().Set("ETag", etag) + return http.StatusCreated, nil +} - if application, ok := r.Context().Value(fsctx.WebDAVCtx).(*model.Webdav); ok && application.UseProxy { - target, err := url.Parse(rs.URL) +func handleOptions(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + allow := []string{"OPTIONS", "LOCK", "PUT", "MKCOL"} + + if user != nil { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { - return http.StatusInternalServerError, err + return status, err } - - r = r.Clone(context.WithValue(r.Context(), fsctx.WebDAVProxyUrlCtx, target)) - // 忽略反向代理在传输错误时报错 - defer func() { - if err := recover(); err != nil && err != http.ErrAbortHandler { - panic(err) + if target, _, err := fm.SharedAddressTranslation(c, reqPath); err == nil { + allow = allow[:1] + read, update, del, create := true, true, true, true + if target.OwnerID() != user.ID { + update = false + del = false + create = false } - }() - proxy.ServeHTTP(w, r) - } else { - http.Redirect(w, r, rs.URL, 301) + if del { + allow = append(allow, "DELETE", "MOVE") + } + if read { + allow = append(allow, "COPY", "PROPFIND") + if target.Type() == types.FileTypeFile { + allow = append(allow, "GET", "HEAD", "POST") + } + } + if update || create { + allow = append(allow, "LOCK", "UNLOCK") + } + if update { + allow = append(allow, "PROPPATCH") + if target.Type() == types.FileTypeFile { + allow = append(allow, "PUT") + } + } + } else { + logging.FromContext(c).Debug("Handle options failed to get target: %s", err) + } } + c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + // http://www.webdav.org/specs/rfc4918.html#dav.compliance.classes + c.Writer.Header().Set("DAV", "1, 2") + // http://msdn.microsoft.com/en-au/library/cc250217.aspx + c.Writer.Header().Set("MS-Author-Via", "DAV") return 0, nil } -// OK -func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - defer fs.Recycle() - - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) +func handleGetHeadPost(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { return status, err } - release, status, err := h.confirmLocks(r, reqPath, "", fs) + target, _, err := fm.SharedAddressTranslation(c, 
reqPath) if err != nil { - return status, err + return purposeStatusCodeFromError(err), err } - defer release() - ctx := r.Context() + if target.Type() != types.FileTypeFile { + return http.StatusMethodNotAllowed, nil + } - // 尝试作为文件删除 - if ok, file := fs.IsFileExist(reqPath); ok { - if err := fs.Delete(ctx, []uint{}, []uint{file.ID}, false, false); err != nil { - return http.StatusMethodNotAllowed, err - } - return http.StatusNoContent, nil + es, err := fm.GetEntitySource(c, target.PrimaryEntityID()) + if err != nil { + return purposeStatusCodeFromError(err), err } - // 尝试作为目录删除 - if ok, folder := fs.IsPathExist(reqPath); ok { - if err := fs.Delete(ctx, []uint{folder.ID}, []uint{}, false, false); err != nil { - return http.StatusMethodNotAllowed, err + defer es.Close() + + es.Apply(entitysource.WithSpeedLimit(int64(user.Edges.Group.SpeedLimit))) + if es.ShouldInternalProxy() || + (user.Edges.DavAccounts[0].Options.Enabled(int(types.DavAccountProxy)) && + user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionWebDAVProxy))) { + es.Serve(c.Writer, c.Request) + } else { + settings := dependency.FromContext(c).SettingProvider() + expire := time.Now().Add(settings.EntityUrlValidDuration(c)) + src, err := es.Url(c, entitysource.WithExpire(&expire)) + if err != nil { + return purposeStatusCodeFromError(err), err } - return http.StatusNoContent, nil + c.Redirect(http.StatusFound, src.Url) } - return http.StatusNotFound, nil + return 0, nil } -// OK -func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) +func handleUnlock(c *gin.Context, user *ent.User, fm manager.FileManager) (retStatus int, retErr error) { + // http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the + // Lock-Token value is a Coded-URL. We strip its angle brackets. + t := c.Request.Header.Get("Lock-Token") + if len(t) < 2 || t[0] != '<' || t[len(t)-1] != '>' { + return http.StatusBadRequest, errInvalidLockToken + } + t = t[1 : len(t)-1] + err := fm.Unlock(c, t) + if err != nil { + return purposeStatusCodeFromError(err), err + } + + return http.StatusNoContent, err +} + +func handleLock(c *gin.Context, user *ent.User, fm manager.FileManager) (retStatus int, retErr error) { + duration, err := parseTimeout(c.Request.Header.Get("Timeout")) + if err != nil { + return http.StatusBadRequest, err + } + li, status, err := readLockInfo(c.Request.Body) if err != nil { return status, err } - release, status, err := h.confirmLocks(r, reqPath, "", fs) + + href, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { return status, err } - defer release() - // TODO(rost): Support the If-Match, If-None-Match headers? See bradfitz' - // comments in http.checkEtag. 
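handleLock above begins by parsing the WebDAV Timeout header through the package's own parseTimeout. Purely as an illustration of the header grammar from RFC 4918 section 10.7 ("Second-n" or "Infinite"), a standalone parser might look like the sketch below; the 24-hour cap for "Infinite" is an assumption of this sketch, not the project's value.

package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"
)

// parseDAVTimeout parses a WebDAV Timeout header such as "Second-3600" or
// "Infinite". Only the first comma-separated alternative is honoured, and
// "Infinite" is clamped to a fixed upper bound in this sketch.
func parseDAVTimeout(s string) (time.Duration, error) {
	const infinite = 24 * time.Hour // arbitrary cap for this sketch
	if s == "" {
		return infinite, nil // no header: the server picks a default
	}
	first := strings.TrimSpace(strings.Split(s, ",")[0])
	if strings.EqualFold(first, "Infinite") {
		return infinite, nil
	}
	if !strings.HasPrefix(first, "Second-") {
		return 0, errors.New("malformed Timeout header")
	}
	n, err := strconv.ParseInt(strings.TrimPrefix(first, "Second-"), 10, 64)
	if err != nil || n < 0 {
		return 0, errors.New("malformed Timeout header")
	}
	return time.Duration(n) * time.Second, nil
}

func main() {
	d, _ := parseDAVTimeout("Second-3600")
	fmt.Println(d) // 1h0m0s
}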
- ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ctx = context.WithValue(ctx, fsctx.HTTPCtx, r.Context()) - ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) - - fileSize, err := strconv.ParseUint(r.Header.Get("Content-Length"), 10, 64) - if err != nil { - return http.StatusMethodNotAllowed, err - } - fileName := path.Base(reqPath) - filePath := path.Dir(reqPath) - fileData := fsctx.FileStream{ - MimeType: r.Header.Get("Content-Type"), - File: r.Body, - Size: fileSize, - Name: fileName, - VirtualPath: filePath, - } - - // 判断文件是否已存在 - exist, originFile := fs.IsFileExist(reqPath) - if exist { - // 已存在,为更新操作 - - // 检查此文件是否有软链接 - fileList, err := model.RemoveFilesWithSoftLinks([]model.File{*originFile}) - if err == nil && len(fileList) == 0 { - // 如果包含软连接,应重新生成新文件副本,并更新source_name - originFile.SourceName = fs.GenerateSavePath(ctx, &fileData) - fileData.Mode &= ^fsctx.Overwrite - fs.Use("AfterUpload", filesystem.HookUpdateSourceName) - fs.Use("AfterUploadCanceled", filesystem.HookUpdateSourceName) - fs.Use("AfterValidateFailed", filesystem.HookUpdateSourceName) - } - fs.Use("BeforeUpload", filesystem.HookResetPolicy) - fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("BeforeUpload", filesystem.HookValidateCapacityDiff) - fs.Use("AfterUploadCanceled", filesystem.HookCleanFileContent) - fs.Use("AfterUploadCanceled", filesystem.HookClearFileSize) - fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) - fs.Use("AfterUpload", filesystem.GenericAfterUpdate) - fs.Use("AfterValidateFailed", filesystem.HookCleanFileContent) - fs.Use("AfterValidateFailed", filesystem.HookClearFileSize) - ctx = context.WithValue(ctx, fsctx.FileModelCtx, *originFile) - fileData.Mode |= fsctx.Overwrite + token, ld, created := "", lock.LockDetails{}, false + if li == (lockInfo{}) { + // An empty lockInfo means to refresh the lock. + ih, ok := parseIfHeader(c.Request.Header.Get("If")) + if !ok { + return http.StatusBadRequest, errInvalidIfHeader + } + if len(ih.lists) == 1 && len(ih.lists[0].conditions) == 1 { + token = ih.lists[0].conditions[0].Token + } + if token == "" { + return http.StatusBadRequest, errInvalidLockToken + } + ld, err = fm.Refresh(c, duration, token) + if err != nil { + if errors.Is(err, lock.ErrNoSuchLock) { + return http.StatusPreconditionFailed, err + } + return http.StatusInternalServerError, err + } + ld.Root = href } else { - // 给文件系统分配钩子 - fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("BeforeUpload", filesystem.HookValidateCapacity) - fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) - fs.Use("AfterUploadCanceled", filesystem.HookCancelContext) - fs.Use("AfterUpload", filesystem.GenericAfterUpload) - fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) + // Section 9.10.3 says that "If no Depth header is submitted on a LOCK request, + // then the request MUST act as if a "Depth:infinity" had been submitted." + depth := infiniteDepth + if hdr := c.Request.Header.Get("Depth"); hdr != "" { + depth = parseDepth(hdr) + if depth != 0 && depth != infiniteDepth { + // Section 9.10.3 says that "Values other than 0 or infinity must not be + // used with the Depth header on a LOCK method". 
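Depth values other than 0 or infinity are rejected with 400 Bad Request in the line that follows. For reference, a minimal standalone version of the Depth parsing convention these handlers rely on is sketched below; the sentinel constants are illustrative only (the package defines its own, and this diff moves infiniteDepth to intsets.MaxInt).

package main

import "fmt"

// Sentinels for this sketch; the real package defines its own values.
const (
	infiniteDepth = -1
	invalidDepth  = -2
)

// parseDepthHeader maps the WebDAV Depth header onto 0, 1 or infinity.
// Anything else is reported as invalid so callers can answer 400.
func parseDepthHeader(s string) int {
	switch s {
	case "0":
		return 0
	case "1":
		return 1
	case "infinity":
		return infiniteDepth
	}
	return invalidDepth
}

func main() {
	fmt.Println(parseDepthHeader("infinity"), parseDepthHeader("2")) // -1 -2
}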
+ return http.StatusBadRequest, errInvalidDepth + } + } + + ancestor, uri, err := fm.SharedAddressTranslation(c, reqPath) + if err != nil && !ent.IsNotFound(err) { + return purposeStatusCodeFromError(err), err + } + + ld = lock.LockDetails{ + Root: href, + Duration: duration, + Owner: lock.Owner{Application: lock.Application{InnerXML: li.Owner.InnerXML}}, + ZeroDepth: depth == 0, + } + app := lock.Application{ + Type: string(fs.ApplicationDAV), + InnerXML: li.Owner.InnerXML, + } + ls, err := fm.Lock(c, duration, user, depth == 0, app, uri, "") + if err != nil { + if errors.Is(err, lock.ErrLocked) { + return StatusLocked, err + } + return http.StatusInternalServerError, err + } + token = ls.LastToken() + ctx := fs.LockSessionToContext(c, ls) + defer func() { + if retErr != nil { + _ = fm.Unlock(c, token) + } + }() + + // Create the resource if it didn't previously exist. + hasher := dependency.FromContext(c).HashIDEncoder() + if !ancestor.Uri(false).IsSame(uri, hashid.EncodeUserID(hasher, user.ID)) { + if _, err = fm.Create(ctx, uri, types.FileTypeFile); err != nil { + return purposeStatusCodeFromError(err), err + } + + created = true + } + + // http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the + // Lock-Token value is a Coded-URL. We add angle brackets. + c.Writer.Header().Set("Lock-Token", "<"+token+">") + } + + c.Writer.Header().Set("Content-Type", "application/xml; charset=utf-8") + if created { + // This is "w.WriteHeader(http.StatusCreated)" and not "return + // http.StatusCreated, nil" because we write our own (XML) response to w + // and Handler.ServeHTTP would otherwise write "Created". + c.Writer.WriteHeader(http.StatusCreated) } + writeLockInfo(c.Writer, token, ld) + return 0, nil +} - // rclone 请求 - fs.Use("AfterUpload", filesystem.NewWebdavAfterUploadHook(r)) +func handlePropfind(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + href, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) + if err != nil { + return status, err + } - // 执行上传 - err = fs.Upload(ctx, &fileData) + _, targetPath, err := fm.SharedAddressTranslation(c, reqPath) if err != nil { - return http.StatusMethodNotAllowed, err + return purposeStatusCodeFromError(err), err } - etag, err := findETag(ctx, fs, nil, reqPath, fileData.Model.(*model.File)) + depth := infiniteDepth + if hdr := c.Request.Header.Get("Depth"); hdr != "" { + depth = parseDepth(hdr) + if depth == invalidDepth { + return http.StatusBadRequest, errInvalidDepth + } + } + pf, status, err := readPropfind(c.Request.Body) if err != nil { - return http.StatusInternalServerError, err + return status, err + } + + mw := multistatusWriter{w: c.Writer} + walkFn := func(f fs.File, level int) error { + var pstats []Propstat + if pf.Propname != nil { + pnames, err := propnames(c, f, fm) + if err != nil { + return err + } + pstat := Propstat{Status: http.StatusOK} + for _, xmlname := range pnames { + pstat.Props = append(pstat.Props, Property{XMLName: xmlname}) + } + pstats = append(pstats, pstat) + } else if pf.Allprop != nil { + pstats, err = allprop(c, f, fm, pf.Prop) + } else { + pstats, err = props(c, f, fm, pf.Prop) + } + if err != nil { + return err + } + + p := path.Join(davPrefix, href) + elements := f.Uri(false).Elements() + for i := 0; i < level; i++ { + p = path.Join(p, elements[len(elements)-level+i]) + } + if f.Type() == types.FileTypeFolder { + p = util.FillSlash(p) + } + + return mw.write(makePropstatResponse(p, pstats)) } - w.Header().Set("ETag", etag) - return 
http.StatusCreated, nil -} -// OK -func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - defer fs.Recycle() + if err := fm.Walk(c, targetPath, depth, walkFn, dbfs.WithFilePublicMetadata()); err != nil { + return purposeStatusCodeFromError(err), err + } + + closeErr := mw.close() + if closeErr != nil { + return http.StatusInternalServerError, closeErr + } + return 0, nil +} - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) +func handleDelete(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { return status, err } - release, status, err := h.confirmLocks(r, reqPath, "", fs) + + ancestor, uri, err := fm.SharedAddressTranslation(c, reqPath) + if err != nil { + return purposeStatusCodeFromError(err), err + } + + release, ls, status, err := confirmLock(c, fm, user, ancestor, nil, uri, nil) if err != nil { return status, err } defer release() + ctx := fs.LockSessionToContext(c, ls) - ctx := r.Context() + // TODO: return MultiStatus where appropriate. - if r.ContentLength > 0 { - return http.StatusUnsupportedMediaType, nil + if err := fm.Delete(ctx, []*fs.URI{uri}); err != nil { + return purposeStatusCodeFromError(err), err } - if _, err := fs.CreateDirectory(ctx, reqPath); err != nil { - return http.StatusConflict, err - } - return http.StatusCreated, nil + return http.StatusNoContent, nil } -// OK -func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem) (status int, err error) { - defer fs.Recycle() - - hdr := r.Header.Get("Destination") +func handleCopyMove(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + hdr := c.Request.Header.Get("Destination") if hdr == "" { return http.StatusBadRequest, errInvalidDestination } @@ -478,51 +583,57 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request, fs *fil if err != nil { return http.StatusBadRequest, errInvalidDestination } - //if u.Host != "" && u.Host != r.Host { - // return http.StatusBadGateway, errInvalidDestination - //} + if u.Host != "" && u.Host != c.Request.Host { + return http.StatusBadGateway, errInvalidDestination + } - src, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) + _, src, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { return status, err } - dst, status, err := h.stripPrefix(u.Path, fs.User.ID) + srcTarget, srcUri, err := fm.SharedAddressTranslation(c, src) if err != nil { - return status, err + return purposeStatusCodeFromError(err), err } - if dst == "" { - return http.StatusBadGateway, errInvalidDestination - } - if dst == src { - return http.StatusForbidden, errDestinationEqualsSource + _, dst, status, err := stripPrefix(u.Path, user) + if err != nil { + return status, err } - ctx := r.Context() + dstTarget, dstUri, err := fm.SharedAddressTranslation(c, dst) + if err != nil && !ent.IsNotFound(err) { + return purposeStatusCodeFromError(err), err + } - isExist, target := isPathExist(ctx, fs, src) + _, dstFolderUri, err := fm.SharedAddressTranslation(c, dst.DirUri()) + if err != nil { + return purposeStatusCodeFromError(err), err + } - if !isExist { - return http.StatusNotFound, nil + hasher := dependency.FromContext(c).HashIDEncoder() + if srcUri.IsSame(dstUri, hashid.EncodeUserID(hasher, user.ID)) { + return http.StatusForbidden, errDestinationEqualsSource } - if r.Method == "COPY" { + if c.Request.Method == 
"COPY" { // Section 7.5.1 says that a COPY only needs to lock the destination, // not both destination and source. Strictly speaking, this is racy, // even though a COPY doesn't modify the source, if a concurrent // operation modifies the source. However, the litmus test explicitly // checks that COPYing a locked-by-another source is OK. - release, status, err := h.confirmLocks(r, "", dst, fs) + release, ls, status, err := confirmLock(c, fm, user, dstTarget, nil, dstUri, nil) if err != nil { return status, err } defer release() + ctx := fs.LockSessionToContext(c, ls) // Section 9.8.3 says that "The COPY method on a collection without a Depth // header must act as if a Depth header with value "infinity" was included". depth := infiniteDepth - if hdr := r.Header.Get("Depth"); hdr != "" { + if hdr := c.Request.Header.Get("Depth"); hdr != "" { depth = parseDepth(hdr) if depth != 0 && depth != infiniteDepth { // Section 9.8.3 says that "A client may submit a Depth header on a @@ -530,274 +641,68 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request, fs *fil return http.StatusBadRequest, errInvalidDepth } } - status, err = copyFiles(ctx, fs, target, dst, r.Header.Get("Overwrite") != "F", depth, 0) - if err != nil { - return status, err - } - err = updateCopyMoveModtime(r, fs, dst) - if err != nil { - return http.StatusInternalServerError, err + if err := fm.MoveOrCopy(ctx, []*fs.URI{srcUri}, dstFolderUri, true); err != nil { + return purposeStatusCodeFromError(err), err } - return status, nil } - // windows下,某些情况下(网盘根目录下)Office保存文件时附带的锁token只包含源文件, - // 此处暂时去除了对dst锁的检查 - release, status, err := h.confirmLocks(r, src, "", fs) + release, ls, status, err := confirmLock(c, fm, user, srcTarget, dstTarget, srcUri, dstUri) if err != nil { return status, err } defer release() + ctx := fs.LockSessionToContext(c, ls) // Section 9.9.2 says that "The MOVE method on a collection must act as if // a "Depth: infinity" header was used on it. A client must not submit a // Depth header on a MOVE on a collection with any value but "infinity"." - if hdr := r.Header.Get("Depth"); hdr != "" { + if hdr := c.Request.Header.Get("Depth"); hdr != "" { if parseDepth(hdr) != infiniteDepth { return http.StatusBadRequest, errInvalidDepth } } - status, err = moveFiles(ctx, fs, target, dst, r.Header.Get("Overwrite") == "T") - if err != nil { - return status, err + if err := fm.MoveOrCopy(ctx, []*fs.URI{srcUri}, dstFolderUri, false); err != nil { + return purposeStatusCodeFromError(err), err } - err = updateCopyMoveModtime(r, fs, dst) - if err != nil { - return http.StatusInternalServerError, err - } - return status, nil -} - -// OK -func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem, ls LockSystem) (retStatus int, retErr error) { - defer fs.Recycle() - - duration, err := parseTimeout(r.Header.Get("Timeout")) - if err != nil { - return http.StatusBadRequest, err - } - - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) - if err != nil { - return status, err + if dstUri.Name() != srcUri.Name() { + if _, err := fm.Rename(ctx, dstFolderUri.Join(srcUri.Name()), dstUri.Name()); err != nil { + return purposeStatusCodeFromError(err), err + } } - ////ctx := r.Context() - //token, ld, now, created := "", LockDetails{}, time.Now(), false - //if li == (lockInfo{}) { - // // An empty lockInfo means to refresh the lock. 
- // ih, ok := parseIfHeader(r.Header.Get("If")) - // if !ok { - // return http.StatusBadRequest, errInvalidIfHeader - // } - // if len(ih.lists) == 1 && len(ih.lists[0].conditions) == 1 { - // token = ih.lists[0].conditions[0].Token - // } - // if token == "" { - // return http.StatusBadRequest, errInvalidLockToken - // } - // ld, err = ls.Refresh(now, token, duration) - // if err != nil { - // if err == ErrNoSuchLock { - // return http.StatusPreconditionFailed, err - // } - // return http.StatusInternalServerError, err - // } - // - //} else { - // // Section 9.10.3 says that "If no Depth header is submitted on a LOCK request, - // // then the request MUST act as if a "Depth:infinity" had been submitted." - // depth := infiniteDepth - // if hdr := r.Header.Get("Depth"); hdr != "" { - // depth = parseDepth(hdr) - // if depth != 0 && depth != infiniteDepth { - // // Section 9.10.3 says that "Values other than 0 or infinity must not be - // // used with the Depth header on a LOCK method". - // return http.StatusBadRequest, errInvalidDepth - // } - // } - // reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) - // if err != nil { - // return status, err - // } - // ld = LockDetails{ - // Root: reqPath, - // Duration: duration, - // OwnerXML: li.Owner.InnerXML, - // ZeroDepth: depth == 0, - // } - // token, err = ls.Create(now, ld) - // if err != nil { - // if err == ErrLocked { - // return StatusLocked, err - // } - // return http.StatusInternalServerError, err - // } - // defer func() { - // if retErr != nil { - // ls.Unlock(now, token) - // } - // }() - // - // // Create the resource if it didn't previously exist. - // //if _, err := h.FileSystem.Stat(ctx, reqPath); err != nil { - // // f, err := h.FileSystem.OpenFile(ctx, reqPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) - // // if err != nil { - // // // TODO: detect missing intermediate dirs and return http.StatusConflict? - // // return http.StatusInternalServerError, err - // // } - // // f.Close() - // // created = true - // //} - // - // // http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the - // // Lock-Token value is a Coded-URL. We add angle brackets. - // w.Header().Set("Lock-Token", "<"+token+">") - //} - // - //w.Header().Set("Content-Type", "application/xml; charset=utf-8") - //if created { - // // This is "w.WriteHeader(http.StatusCreated)" and not "return - // // http.StatusCreated, nil" because we write our own (XML) response to w - // // and Handler.ServeHTTP would otherwise write "Created". - // w.WriteHeader(http.StatusCreated) - //} - - writeLockInfo(w, fmt.Sprintf("%d", time.Now().UnixNano()), LockDetails{ - Duration: duration, - OwnerXML: fs.User.Email, - Root: reqPath, - }) - return 0, nil -} - -// OK -func (h *Handler) handleUnlock(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem, ls LockSystem) (status int, err error) { - defer fs.Recycle() - return http.StatusNoContent, err - - //// http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the - //// Lock-Token value is a Coded-URL. We strip its angle brackets. 
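The replacement handleUnlock keeps the behaviour described in the comment above: the Lock-Token header carries a Coded-URL whose angle brackets are stripped before the token is looked up. A tiny standalone illustration:

package main

import (
	"errors"
	"fmt"
)

// stripCodedURL removes the angle brackets around a Coded-URL such as
// "<urn:uuid:...>" (RFC 4918 section 10.5); anything shorter or unbracketed
// is treated as malformed, which the handler answers with 400.
func stripCodedURL(t string) (string, error) {
	if len(t) < 2 || t[0] != '<' || t[len(t)-1] != '>' {
		return "", errors.New("invalid lock token")
	}
	return t[1 : len(t)-1], nil
}

func main() {
	tok, err := stripCodedURL("<urn:uuid:150852e2-3847-42d5-8cbe-0f4f296f26cf>")
	fmt.Println(tok, err)
}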
- //t := r.Header.Get("Lock-Token") - //if len(t) < 2 || t[0] != '<' || t[len(t)-1] != '>' { - // return http.StatusBadRequest, errInvalidLockToken - //} - //t = t[1 : len(t)-1] - // - //switch err = ls.Unlock(time.Now(), t); err { - //case nil: - // return http.StatusNoContent, err - //case ErrForbidden: - // return http.StatusForbidden, err - //case ErrLocked: - // return StatusLocked, err - //case ErrNoSuchLock: - // return http.StatusConflict, err - //default: - // return http.StatusInternalServerError, err - //} + return http.StatusNoContent, nil } -// OK -func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem, ls LockSystem) (status int, err error) { - defer fs.Recycle() - - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) +func handleProppatch(c *gin.Context, user *ent.User, fm manager.FileManager) (status int, err error) { + _, reqPath, status, err := stripPrefix(c.Request.URL.Path, user) if err != nil { return status, err } - ctx := r.Context() - ok, fi := isPathExist(ctx, fs, reqPath) - if !ok { - return http.StatusNotFound, err - } - depth := infiniteDepth - if hdr := r.Header.Get("Depth"); hdr != "" { - depth = parseDepth(hdr) - if depth == invalidDepth { - return http.StatusBadRequest, errInvalidDepth - } - } - pf, status, err := readPropfind(r.Body) + ancestor, uri, err := fm.SharedAddressTranslation(c, reqPath) if err != nil { - return status, err + return purposeStatusCodeFromError(err), err } - mw := multistatusWriter{w: w} - - walkFn := func(reqPath string, info FileInfo, err error) error { - - if err != nil { - return err - } - var pstats []Propstat - if pf.Propname != nil { - pnames, err := propnames(ctx, fs, ls, info) - if err != nil { - return err - } - pstat := Propstat{Status: http.StatusOK} - for _, xmlname := range pnames { - pstat.Props = append(pstat.Props, Property{XMLName: xmlname}) - } - pstats = append(pstats, pstat) - } else if pf.Allprop != nil { - pstats, err = allprop(ctx, fs, ls, info, pf.Prop) - } else { - pstats, err = props(ctx, fs, ls, info, pf.Prop) - } - if err != nil { - return err - } - href := path.Join(h.Prefix, reqPath) - if href != "/" && info.IsDir() { - href += "/" - } - return mw.write(makePropstatResponse(href, pstats)) - } - - walkErr := walkFS(ctx, fs, depth, reqPath, fi, walkFn) - closeErr := mw.close() - if walkErr != nil { - return http.StatusInternalServerError, walkErr - } - if closeErr != nil { - return http.StatusInternalServerError, closeErr - } - return 0, nil -} - -func (h *Handler) handleProppatch(w http.ResponseWriter, r *http.Request, fs *filesystem.FileSystem, ls LockSystem) (status int, err error) { - defer fs.Recycle() - - reqPath, status, err := h.stripPrefix(r.URL.Path, fs.User.ID) - if err != nil { - return status, err - } - release, status, err := h.confirmLocks(r, reqPath, "", fs) + release, ls, status, err := confirmLock(c, fm, user, ancestor, nil, uri, nil) if err != nil { return status, err } defer release() + ctx := fs.LockSessionToContext(c, ls) - ctx := r.Context() - - if exist, _ := isPathExist(ctx, fs, reqPath); !exist { - return http.StatusNotFound, nil - } - patches, status, err := readProppatch(r.Body) + patches, status, err := readProppatch(c.Request.Body) if err != nil { return status, err } - pstats, err := patch(ctx, fs, ls, reqPath, patches) + pstats, err := patch(ctx, ancestor, fm, patches) if err != nil { return http.StatusInternalServerError, err } - mw := multistatusWriter{w: w} - writeErr := mw.write(makePropstatResponse(r.URL.Path, 
pstats)) + mw := multistatusWriter{w: c.Writer} + writeErr := mw.write(makePropstatResponse(c.Request.URL.Path, pstats)) closeErr := mw.close() if writeErr != nil { return http.StatusInternalServerError, writeErr @@ -808,6 +713,39 @@ func (h *Handler) handleProppatch(w http.ResponseWriter, r *http.Request, fs *fi return 0, nil } +func purposeStatusCodeFromError(err error) int { + if ent.IsNotFound(err) { + return http.StatusNotFound + } + + if errors.Is(err, lock.ErrNoSuchLock) { + return http.StatusConflict + } + + var ae *serializer.AggregateError + if errors.As(err, &ae) && len(ae.Raw()) > 0 { + for _, e := range ae.Raw() { + return purposeStatusCodeFromError(e) + } + } + + var appErr serializer.AppError + if errors.As(err, &appErr) { + switch appErr.Code { + case serializer.CodeNotFound, serializer.CodeParentNotExist, serializer.CodeEntityNotExist: + return http.StatusNotFound + case serializer.CodeNoPermissionErr: + return http.StatusForbidden + case serializer.CodeLockConflict: + return http.StatusLocked + case serializer.CodeObjectExist: + return http.StatusMethodNotAllowed + } + } + + return http.StatusInternalServerError +} + func makePropstatResponse(href string, pstats []Propstat) *response { resp := response{ Href: []string{(&url.URL{Path: href}).EscapedPath()}, @@ -829,7 +767,7 @@ func makePropstatResponse(href string, pstats []Propstat) *response { } const ( - infiniteDepth = -1 + infiniteDepth = intsets.MaxInt invalidDepth = -2 ) diff --git a/pkg/webdav/xml.go b/pkg/webdav/xml.go index 34b4f668..1518d58c 100644 --- a/pkg/webdav/xml.go +++ b/pkg/webdav/xml.go @@ -11,6 +11,7 @@ import ( "bytes" "encoding/xml" "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" "io" "net/http" "time" @@ -32,7 +33,7 @@ import ( // In the long term, this package should use the standard library's version // only, and the internal fork deleted, once // https://github.com/golang/go/issues/13400 is resolved. 
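purposeStatusCodeFromError above centralises the mapping from domain errors to WebDAV status codes with errors.Is / errors.As. A self-contained sketch of the same pattern using stand-in error types only; the real function inspects ent, lock and serializer errors, which are not reproduced here.

package main

import (
	"errors"
	"fmt"
	"net/http"
)

// Stand-in sentinel and typed errors for this sketch only.
var errNotFound = errors.New("not found")

type conflictError struct{ msg string }

func (e *conflictError) Error() string { return e.msg }

// statusFromError mirrors the idea behind purposeStatusCodeFromError:
// unwrap with errors.Is / errors.As and fall back to 500.
func statusFromError(err error) int {
	if errors.Is(err, errNotFound) {
		return http.StatusNotFound
	}
	var ce *conflictError
	if errors.As(err, &ce) {
		return http.StatusConflict
	}
	return http.StatusInternalServerError
}

func main() {
	fmt.Println(statusFromError(fmt.Errorf("stat: %w", errNotFound))) // 404
	fmt.Println(statusFromError(&conflictError{"lock conflict"}))     // 409
	fmt.Println(statusFromError(errors.New("boom")))                  // 500
}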
- ixml "github.com/cloudreve/Cloudreve/v3/pkg/webdav/internal/xml" + ixml "github.com/cloudreve/Cloudreve/v4/pkg/webdav/internal/xml" ) // http://www.webdav.org/specs/rfc4918.html#ELEMENT_lockinfo @@ -81,7 +82,7 @@ func (c *countingReader) Read(p []byte) (int, error) { return n, err } -func writeLockInfo(w io.Writer, token string, ld LockDetails) (int, error) { +func writeLockInfo(w io.Writer, token string, ld lock.LockDetails) (int, error) { depth := "infinity" if ld.ZeroDepth { depth = "0" @@ -95,9 +96,9 @@ func writeLockInfo(w io.Writer, token string, ld LockDetails) (int, error) { " %s\n"+ " Second-%d\n"+ " %s\n"+ - " %s\n"+ + " /%s\n"+ "", - depth, ld.OwnerXML, timeout, escape(token), escape(ld.Root), + depth, ld.Owner.Application.InnerXML, timeout, escape(token), escape(ld.Root), ) } diff --git a/pkg/wopi/discovery.go b/pkg/wopi/discovery.go index a9b6944e..64117309 100644 --- a/pkg/wopi/discovery.go +++ b/pkg/wopi/discovery.go @@ -3,9 +3,9 @@ package wopi import ( "encoding/xml" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "net/http" - "strings" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gofrs/uuid" + "github.com/samber/lo" ) type ActonType string @@ -16,86 +16,53 @@ var ( ActionEdit = ActonType("edit") ) -const ( - DiscoverResponseCacheKey = "wopi_discover" - DiscoverRefreshDuration = 24 * 3600 // 24 hrs -) - -func (c *client) AvailableExts() []string { - if err := c.checkDiscovery(); err != nil { - util.Log().Error("Failed to check WOPI discovery: %s", err) - return nil - } - - c.mu.RLock() - defer c.mu.RUnlock() - exts := make([]string, 0, len(c.actions)) - for ext, actions := range c.actions { - _, previewable := actions[string(ActionPreview)] - _, editable := actions[string(ActionEdit)] - _, previewableFallback := actions[string(ActionPreviewFallback)] - - if previewable || editable || previewableFallback { - exts = append(exts, strings.TrimPrefix(ext, ".")) - } +func DiscoveryXmlToViewerGroup(xmlStr string) (*setting.ViewerGroup, error) { + var discovery WopiDiscovery + if err := xml.Unmarshal([]byte(xmlStr), &discovery); err != nil { + return nil, fmt.Errorf("failed to parse WOPI discovery XML: %w", err) } - return exts -} - -// checkDiscovery checks if discovery content is needed to be refreshed. -// If so, it will refresh discovery content. -func (c *client) checkDiscovery() error { - c.mu.RLock() - if c.discovery == nil { - c.mu.RUnlock() - return c.refreshDiscovery() + group := &setting.ViewerGroup{ + Viewers: make([]setting.Viewer, 0, len(discovery.NetZone.App)), } - c.mu.RUnlock() - return nil -} - -// refresh Discovery action configs. -func (c *client) refreshDiscovery() error { - c.mu.Lock() - defer c.mu.Unlock() - - cached, exist := c.cache.Get(DiscoverResponseCacheKey) - if exist { - cachedDiscovery := cached.(WopiDiscovery) - c.discovery = &cachedDiscovery - } else { - res, err := c.http.Request("GET", c.config.discoveryEndpoint.String(), nil). 
- CheckHTTPResponse(http.StatusOK).GetResponse() - if err != nil { - return fmt.Errorf("failed to request discovery endpoint: %w", err) - } - - if err := xml.Unmarshal([]byte(res), &c.discovery); err != nil { - return fmt.Errorf("failed to parse response discovery endpoint: %w", err) - } - - if err := c.cache.Set(DiscoverResponseCacheKey, *c.discovery, DiscoverRefreshDuration); err != nil { - return err + for _, app := range discovery.NetZone.App { + viewer := setting.Viewer{ + ID: uuid.Must(uuid.NewV4()).String(), + DisplayName: app.Name, + Type: setting.ViewerTypeWopi, + Icon: app.FavIconUrl, + WopiActions: make(map[string]map[setting.ViewerAction]string), } - } - // construct actions map - c.actions = make(map[string]map[string]Action) - for _, app := range c.discovery.NetZone.App { for _, action := range app.Action { if action.Ext == "" { continue } - if _, ok := c.actions["."+action.Ext]; !ok { - c.actions["."+action.Ext] = make(map[string]Action) + if _, ok := viewer.WopiActions[action.Ext]; !ok { + viewer.WopiActions[action.Ext] = make(map[setting.ViewerAction]string) + } + + if action.Name == string(ActionPreview) { + viewer.WopiActions[action.Ext][setting.ViewerActionView] = action.Urlsrc + } else if action.Name == string(ActionPreviewFallback) { + viewer.WopiActions[action.Ext][setting.ViewerActionView] = action.Urlsrc + } else if action.Name == string(ActionEdit) { + viewer.WopiActions[action.Ext][setting.ViewerActionEdit] = action.Urlsrc + } else if len(viewer.WopiActions[action.Ext]) == 0 { + delete(viewer.WopiActions, action.Ext) } + } + + viewer.Exts = lo.MapToSlice(viewer.WopiActions, func(key string, value map[setting.ViewerAction]string) string { + return key + }) - c.actions["."+action.Ext][action.Name] = action + if len(viewer.WopiActions) > 0 { + group.Viewers = append(group.Viewers, viewer) } } - return nil + return group, nil } diff --git a/pkg/wopi/discovery_test.go b/pkg/wopi/discovery_test.go index 80923847..1a945aea 100644 --- a/pkg/wopi/discovery_test.go +++ b/pkg/wopi/discovery_test.go @@ -1,129 +1,438 @@ package wopi import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io" - "net/http" - "net/url" - "strings" + "fmt" "testing" ) -func TestClient_AvailableExts(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - discoveryEndpoint: endpoint, - }, - } - - // Discovery failed - { - expectedErr := errors.New("error") - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: expectedErr, - }) - res := client.AvailableExts() - a.Empty(res) - mockHttp.AssertExpectations(t) - } - - // pass - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{}, - }, - ".ppt": {}, - ".xls": { - "not_supported": Action{}, - }, - } - res := client.AvailableExts() - a.Len(res, 1) - a.Equal("doc", res[0]) - } -} - -func TestClient_RefreshDiscovery(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - 
discoveryEndpoint: endpoint, - }, - } - - // cache hit - { - client.cache.Set(DiscoverResponseCacheKey, WopiDiscovery{Text: "test"}, 0) - a.NoError(client.checkDiscovery()) - a.Equal("test", client.discovery.Text) - client.discovery = &WopiDiscovery{} - client.cache.Delete([]string{DiscoverResponseCacheKey}, "") - } - - // malformed xml - { - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - res := client.refreshDiscovery() - a.ErrorContains(res, "failed to parse") - mockHttp.AssertExpectations(t) - } - - // all pass - { - testResponse := ` -` - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(testResponse)), - }, - }) - res := client.refreshDiscovery() - a.NoError(res, res) - a.NotEmpty(client.actions[".docx"]) - a.NotEmpty(client.actions[".docx"][string(ActionPreview)]) - a.NotEmpty(client.actions[".docx"][string(ActionEdit)]) - mockHttp.AssertExpectations(t) - } +func TestDiscoveryXmlToViewerGroup(t *testing.T) { + xmlSrc := ` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +` + group, _ := DiscoveryXmlToViewerGroup(xmlSrc) + fmt.Print(group) } diff --git a/pkg/wopi/types.go b/pkg/wopi/types.go index a9425f4a..97b69335 100644 --- a/pkg/wopi/types.go +++ b/pkg/wopi/types.go @@ -63,6 +63,10 @@ type SessionCache struct { Action ActonType } +const ( + WopiSessionCtx = "wopi_session" +) + func init() { gob.Register(WopiDiscovery{}) gob.Register(Action{}) diff --git a/pkg/wopi/wopi.go b/pkg/wopi/wopi.go index 2938de04..49f7a38b 100644 --- a/pkg/wopi/wopi.go +++ b/pkg/wopi/wopi.go @@ -1,34 +1,22 @@ package wopi import ( + "context" "errors" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gofrs/uuid" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" "net/url" - "path" "strings" - "sync" "time" ) -type Client interface { - // NewSession creates a new document session with access token. 
- NewSession(uid uint, file *model.File, action ActonType) (*Session, error) - // AvailableExts returns a list of file extensions that are supported by WOPI. - AvailableExts() []string -} - var ( ErrActionNotSupported = errors.New("action not supported by current wopi endpoint") - Default Client - DefaultMu sync.Mutex - queryPlaceholders = map[string]string{ "BUSINESS_USER": "", "DC_LLCC": "lng", @@ -48,136 +36,55 @@ var ( const ( SessionCachePrefix = "wopi_session_" AccessTokenQuery = "access_token" - OverwriteHeader = wopiHeaderPrefix + "Override" - ServerErrorHeader = wopiHeaderPrefix + "ServerError" - RenameRequestHeader = wopiHeaderPrefix + "RequestedName" + OverwriteHeader = WopiHeaderPrefix + "Override" + ServerErrorHeader = WopiHeaderPrefix + "ServerError" + RenameRequestHeader = WopiHeaderPrefix + "RequestedName" + LockTokenHeader = WopiHeaderPrefix + "Lock" + ItemVersionHeader = WopiHeaderPrefix + "ItemVersion" MethodLock = "LOCK" MethodUnlock = "UNLOCK" MethodRefreshLock = "REFRESH_LOCK" - MethodRename = "RENAME_FILE" - - wopiSrcPlaceholder = "WOPI_SOURCE" - wopiSrcParamDefault = "WOPISrc" - languageParamDefault = "lang" - sessionExpiresPadding = 10 - wopiHeaderPrefix = "X-WOPI-" -) - -// Init initializes a new global WOPI client. -func Init() { - settings := model.GetSettingByNames("wopi_endpoint", "wopi_enabled") - if !model.IsTrueVal(settings["wopi_enabled"]) { - DefaultMu.Lock() - Default = nil - DefaultMu.Unlock() - return - } - - cache.Deletes([]string{DiscoverResponseCacheKey}, "") - wopiClient, err := NewClient(settings["wopi_endpoint"], cache.Store, request.NewClient()) - if err != nil { - util.Log().Error("Failed to initialize WOPI client: %s", err) - return - } - - DefaultMu.Lock() - Default = wopiClient - DefaultMu.Unlock() -} - -type client struct { - cache cache.Driver - http request.Client - mu sync.RWMutex - - discovery *WopiDiscovery - actions map[string]map[string]Action - - config -} - -type config struct { - discoveryEndpoint *url.URL -} -func NewClient(endpoint string, cache cache.Driver, http request.Client) (Client, error) { - endpointUrl, err := url.Parse(endpoint) - if err != nil { - return nil, fmt.Errorf("failed to parse WOPI endpoint: %s", err) - } + wopiSrcPlaceholder = "WOPI_SOURCE" + wopiSrcParamDefault = "WOPISrc" + languageParamDefault = "lang" + WopiHeaderPrefix = "X-WOPI-" - return &client{ - cache: cache, - http: http, - config: config{ - discoveryEndpoint: endpointUrl, - }, - }, nil -} - -func (c *client) NewSession(uid uint, file *model.File, action ActonType) (*Session, error) { - if err := c.checkDiscovery(); err != nil { - return nil, err - } + LockDuration = time.Duration(30) * time.Minute +) - c.mu.RLock() - defer c.mu.RUnlock() +func GenerateWopiSrc(ctx context.Context, action setting.ViewerAction, viewer *setting.Viewer, viewerSession *manager.ViewerSession) (*url.URL, error) { + dep := dependency.FromContext(ctx) + base := dep.SettingProvider().SiteURL(setting.UseFirstSiteUrl(ctx)) + hasher := dep.HashIDEncoder() - ext := path.Ext(file.Name) - availableActions, ok := c.actions[ext] + availableActions, ok := viewer.WopiActions[viewerSession.File.Ext()] if !ok { return nil, ErrActionNotSupported } var ( - actionConfig Action + src string ) - fallbackOrder := []ActonType{action, ActionPreview, ActionPreviewFallback, ActionEdit} + fallbackOrder := []setting.ViewerAction{action, setting.ViewerActionView, setting.ViewerActionEdit} for _, a := range fallbackOrder { - if actionConfig, ok = availableActions[string(a)]; ok { + if src, ok = 
availableActions[a]; ok { break } } - if actionConfig.Urlsrc == "" { + if src == "" { return nil, ErrActionNotSupported } - // Generate WOPI REST endpoint for given file - baseURL := model.GetSiteURL() - linkPath, err := url.Parse(fmt.Sprintf("/api/v3/wopi/files/%s", hashid.HashID(file.ID, hashid.FileID))) - if err != nil { - return nil, err - } - - actionUrl, err := generateActionUrl(actionConfig.Urlsrc, baseURL.ResolveReference(linkPath).String()) + actionUrl, err := generateActionUrl(src, + routes.MasterWopiSrc(base, hashid.EncodeFileID(hasher, viewerSession.File.ID())).String()) if err != nil { return nil, err } - // Create document session - sessionID := uuid.Must(uuid.NewV4()) - token := util.RandStringRunes(64) - ttl := model.GetIntSetting("wopi_session_timeout", 36000) - session := &SessionCache{ - AccessToken: fmt.Sprintf("%s.%s", sessionID, token), - FileID: file.ID, - UserID: uid, - Action: action, - } - err = c.cache.Set(SessionCachePrefix+sessionID.String(), *session, ttl) - if err != nil { - return nil, fmt.Errorf("failed to create document session: %w", err) - } - - sessionRes := &Session{ - AccessToken: session.AccessToken, - ActionURL: actionUrl, - AccessTokenTTL: time.Now().Add(time.Duration(ttl-sessionExpiresPadding) * time.Second).UnixMilli(), - } - - return sessionRes, nil + return actionUrl, nil } // Replace query parameters in action URL template. Some placeholders need to be replaced diff --git a/pkg/wopi/wopi_test.go b/pkg/wopi/wopi_test.go deleted file mode 100644 index 78c4bcc4..00000000 --- a/pkg/wopi/wopi_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package wopi - -import ( - "database/sql" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/cachemock" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "net/url" - "testing" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestNewSession(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - discoveryEndpoint: endpoint, - }, - } - - // Discovery failed - { - expectedErr := errors.New("error") - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: expectedErr, - }) - res, err := client.NewSession(0, &model.File{}, ActionPreview) - a.Nil(res) - a.ErrorIs(err, expectedErr) - mockHttp.AssertExpectations(t) - } - - // not supported ext - { - client.discovery = &WopiDiscovery{} - client.actions = make(map[string]map[string]Action) - res, err := client.NewSession(0, &model.File{}, ActionPreview) - a.Nil(res) - a.ErrorIs(err, ErrActionNotSupported) - } - - // preferred action not supported - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": {}, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionPreview) 
- a.Nil(res) - a.ErrorIs(err, ErrActionNotSupported) - } - - // src url cannot be parsed - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: string([]byte{0x7f}), - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.Nil(res) - a.ErrorContains(err, "invalid control character in URL") - } - - // all pass - default placeholder - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc", - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.NotNil(res) - a.NoError(err) - resUrl := res.ActionURL.String() - a.Contains(resUrl, wopiSrcParamDefault) - } - - // all pass - with placeholders - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc?origin=preserved&", - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.NotNil(res) - a.NoError(err) - resUrl := res.ActionURL.String() - a.Contains(resUrl, "origin=preserved") - a.Contains(resUrl, "dc=lng") - a.Contains(resUrl, "src=") - a.NotContains(resUrl, "notsuported") - } - - // cache operation failed - { - mockCache := &cachemock.CacheClientMock{} - expectedErr := errors.New("error") - client.cache = mockCache - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc", - }, - }, - } - mockCache.On("Set", testMock.Anything, testMock.Anything, testMock.Anything).Return(expectedErr) - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.Nil(res) - a.ErrorIs(err, expectedErr) - } -} - -func TestInit(t *testing.T) { - a := assert.New(t) - - // not enabled - { - a.Nil(Default) - Default = &client{} - Init() - a.Nil(Default) - } - - // throw error - { - a.Nil(Default) - cache.Set("setting_wopi_enabled", "1", 0) - cache.Set("setting_wopi_endpoint", string([]byte{0x7f}), 0) - Init() - a.Nil(Default) - } - - // all pass - { - a.Nil(Default) - cache.Set("setting_wopi_enabled", "1", 0) - cache.Set("setting_wopi_endpoint", "", 0) - Init() - a.NotNil(Default) - } -} diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index 26d917e1..2be65b0a 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -1,511 +1,588 @@ package controllers import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "io" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" - "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/cloudreve/Cloudreve/v3/service/admin" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/admin" "github.com/gin-gonic/gin" ) // AdminSummary 获取管理站点概况 func AdminSummary(c *gin.Context) { - var service admin.NoParamService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Summary() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*admin.SummaryService](c, admin.SummaryParamCtx{}) + 
res, err := service.Summary(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminNews 获取社区新闻 -func AdminNews(c *gin.Context) { - tag := "announcements" - if c.Query("tag") != "" { - tag = c.Query("tag") +// AdminGetSettings 获取站点设定项 +func AdminGetSettings(c *gin.Context) { + service := ParametersFromContext[*admin.GetSettingService](c, admin.GetSettingParamCtx{}) + res, err := service.GetSetting(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } - r := request.NewClient() - res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3A"+tag+"&sort=-startTime&page%5Blimit%5D=10", nil) - if res.Err == nil { - io.Copy(c.Writer, res.Response.Body) + + c.JSON(200, serializer.Response{Data: res}) +} + +func AdminSetSettings(c *gin.Context) { + service := ParametersFromContext[*admin.SetSettingService](c, admin.SetSettingParamCtx{}) + res, err := service.SetSetting(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminChangeSetting 获取站点设定项 -func AdminChangeSetting(c *gin.Context) { - var service admin.BatchSettingChangeService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Change() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminListGroups 获取用户组列表 +func AdminListGroups(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.List(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminGetSetting 获取站点设置 -func AdminGetSetting(c *gin.Context) { - var service admin.BatchSettingGet - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Get() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminFetchWopi(c *gin.Context) { + service := ParametersFromContext[*admin.FetchWOPIDiscoveryService](c, admin.FetchWOPIDiscoveryParamCtx{}) + res, err := service.Fetch(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminGetGroups 获取用户组列表 -func AdminGetGroups(c *gin.Context) { - var service admin.NoParamService - if err := c.ShouldBindUri(&service); err == nil { - res := service.GroupList() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminTestThumbGenerator Tests thumb generator +func AdminTestThumbGenerator(c *gin.Context) { + service := ParametersFromContext[*admin.ThumbGeneratorTestService](c, admin.ThumbGeneratorTestParamCtx{}) + res, err := service.Test(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminReloadService 重新加载子服务 -func AdminReloadService(c *gin.Context) { - service := c.Param("service") - switch service { - case "email": - email.Init() - case "aria2": - aria2.Init(true, cluster.Default, mq.GlobalMQ) - case "wopi": - wopi.Init() +func AdminGetQueueMetrics(c *gin.Context) { + res, err := admin.GetQueueMetrics(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) +} - c.JSON(200, serializer.Response{}) +func AdminListPolicies(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Policies(c) + if err != nil { 
+ c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +func AdminGetPolicy(c *gin.Context) { + service := ParametersFromContext[*admin.SingleStoragePolicyService](c, admin.GetStoragePolicyParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) } // AdminSendTestMail 发送测试邮件 func AdminSendTestMail(c *gin.Context) { - var service admin.MailTestService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Send() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*admin.TestSMTPService](c, admin.TestSMTPParamCtx{}) + err := service.Test(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminTestThumbGenerator Tests thumb generator -func AdminTestThumbGenerator(c *gin.Context) { - var service admin.ThumbGeneratorTestService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Test(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminCreatePolicy(c *gin.Context) { + service := ParametersFromContext[*admin.CreateStoragePolicyService](c, admin.CreateStoragePolicyParamCtx{}) + res, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminTestAria2 测试aria2连接 -func AdminTestAria2(c *gin.Context) { - var service admin.Aria2TestService - if err := c.ShouldBindJSON(&service); err == nil { - var res serializer.Response - if service.Type == model.MasterNodeType { - res = service.TestMaster() - } else { - res = service.TestSlave() - } +func AdminUpdatePolicy(c *gin.Context) { + service := ParametersFromContext[*admin.UpdateStoragePolicyService](c, admin.UpdateStoragePolicyParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListNodes(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Nodes(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListPolicy 列出存储策略 -func AdminListPolicy(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Policies() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetNode(c *gin.Context) { + service := ParametersFromContext[*admin.SingleNodeService](c, admin.SingleNodeParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminTestPath 测试本地路径可用性 -func AdminTestPath(c *gin.Context) { - var service admin.PathTestService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Test() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminClearEntityUrlCache(c *gin.Context) { + admin.ClearEntityUrlCache(c) + c.JSON(200, serializer.Response{}) +} + +func AdminCreateStoragePolicyCors(c *gin.Context) { + service := ParametersFromContext[*admin.CreateStoragePolicyCorsService](c, admin.CreateStoragePolicyCorsParamCtx{}) + err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + 
c.JSON(200, serializer.Response{}) } -// AdminTestSlave 测试从机可用性 -func AdminTestSlave(c *gin.Context) { - var service admin.SlaveTestService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Test() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminOdOAuthURL(c *gin.Context) { + service := ParametersFromContext[*admin.GetOauthRedirectService](c, admin.GetOauthRedirectParamCtx{}) + res, err := service.GetOAuth(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminAddPolicy 新建存储策略 -func AdminAddPolicy(c *gin.Context) { - var service admin.AddPolicyService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Add() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetPolicyOAuthCallbackURL(c *gin.Context) { + res := admin.GetPolicyOAuthURL(c) + c.JSON(200, serializer.Response{Data: res}) +} + +func AdminGetPolicyOAuthStatus(c *gin.Context) { + service := ParametersFromContext[*admin.SingleStoragePolicyService](c, admin.GetStoragePolicyParamCtx{}) + res, err := service.GetOauthCredentialStatus(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminAddCORS 创建跨域策略 -func AdminAddCORS(c *gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.AddCORS() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminFinishOauthCallback(c *gin.Context) { + service := ParametersFromContext[*admin.FinishOauthCallbackService](c, admin.FinishOauthCallbackParamCtx{}) + err := service.Finish(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminAddSCF 创建回调函数 -func AdminAddSCF(c *gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.AddSCF() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetSharePointDriverRoot(c *gin.Context) { + service := ParametersFromContext[*admin.SingleStoragePolicyService](c, admin.GetStoragePolicyParamCtx{}) + res, err := service.GetSharePointDriverRoot(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminOAuthURL 获取 OneDrive OAuth URL -func AdminOAuthURL(policyType string) gin.HandlerFunc { - return func(c *gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindUri(&service); err == nil { - res := service.GetOAuth(c, policyType) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } +func AdminDeletePolicy(c *gin.Context) { + service := ParametersFromContext[*admin.SingleStoragePolicyService](c, admin.GetStoragePolicyParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminGetPolicy 获取存储策略详情 -func AdminGetPolicy(c *gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetGroup(c *gin.Context) { + service := ParametersFromContext[*admin.SingleGroupService](c, admin.SingleGroupParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeletePolicy 删除存储策略 -func AdminDeletePolicy(c 
*gin.Context) { - var service admin.PolicyService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminCreateGroup(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertGroupService](c, admin.UpsertGroupParamCtx{}) + res, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListGroup 列出用户组 -func AdminListGroup(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Groups() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminUpdateGroup(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertGroupService](c, admin.UpsertGroupParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminAddGroup 新建用户组 -func AdminAddGroup(c *gin.Context) { - var service admin.AddGroupService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Add() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListUsers(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Users(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteGroup 删除用户组 +func AdminGetUser(c *gin.Context) { + service := ParametersFromContext[*admin.SingleUserService](c, admin.SingleUserParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +func AdminUpdateUser(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertUserService](c, admin.UpsertUserParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +func AdminCreateUser(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertUserService](c, admin.UpsertUserParamCtx{}) + res, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + c.JSON(200, serializer.Response{Data: res}) +} + +// func AdminHashIDEncode(c *gin.Context) { +// service := ParametersFromContext[*admin.HashIDService](c, admin.HashIDParamCtx{}) +// resp, err := service.Encode(c) +// if err != nil { +// c.JSON(200, serializer.Err(c, err)) +// c.Abort() +// return +// } +// +// c.JSON(200, serializer.Response{ +// Data: resp, +// }) +// } +// +// func AdminHashIDDecode(c *gin.Context) { +// service := ParametersFromContext[*admin.HashIDService](c, admin.HashIDParamCtx{}) +// resp, err := service.Decode(c) +// if err != nil { +// c.JSON(200, serializer.Err(c, err)) +// c.Abort() +// return +// } +// +// c.JSON(200, serializer.Response{ +// Data: resp, +// }) +// } +// +// func AdminBsEncode(c *gin.Context) { +// service := ParametersFromContext[*admin.BsEncodeService](c, admin.BsEncodeParamCtx{}) +// resp, err := service.Encode(c) +// if err != nil { +// c.JSON(200, serializer.Err(c, err)) +// c.Abort() +// return +// } +// +// c.JSON(200, serializer.Response{ +// Data: resp, +// }) +// } +// +// func AdminBsDecode(c *gin.Context) { +// service := ParametersFromContext[*admin.BsDecodeService](c, admin.BsDecodeParamCtx{}) +// resp, err := service.Decode(c) 
+// if err != nil { +// c.JSON(200, serializer.Err(c, err)) +// c.Abort() +// return +// } +// +// c.JSON(200, serializer.Response{ +// Data: resp, +// }) +// } +// + func AdminDeleteGroup(c *gin.Context) { - var service admin.GroupService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*admin.SingleGroupService](c, admin.SingleGroupParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminGetGroup 获取用户组详情 -func AdminGetGroup(c *gin.Context) { - var service admin.GroupService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminTestSlave 测试从机可用性 +func AdminTestSlave(c *gin.Context) { + service := ParametersFromContext[*admin.TestNodeService](c, admin.TestNodeParamCtx{}) + err := service.Test(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminListUser 列出用户 -func AdminListUser(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Users() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminTestDownloader 测试下载器连接 +func AdminTestDownloader(c *gin.Context) { + service := ParametersFromContext[*admin.TestNodeDownloaderService](c, admin.TestNodeDownloaderParamCtx{}) + res, err := service.Test(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminAddUser 新建用户组 -func AdminAddUser(c *gin.Context) { - var service admin.AddUserService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Add() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminCreateNode(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertNodeService](c, admin.UpsertNodeParamCtx{}) + res, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminGetUser 获取用户详情 -func AdminGetUser(c *gin.Context) { - var service admin.UserService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminUpdateNode(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertNodeService](c, admin.UpsertNodeParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteUser 批量删除用户 -func AdminDeleteUser(c *gin.Context) { - var service admin.UserBatchService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Delete() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminDeleteNode(c *gin.Context) { + service := ParametersFromContext[*admin.SingleNodeService](c, admin.SingleNodeParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminBanUser 封禁/解封用户 -func AdminBanUser(c *gin.Context) { - var service admin.UserService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Ban() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// AdminDeleteUser 批量删除用户 +func AdminDeleteUser(c *gin.Context) { + service := 
ParametersFromContext[*admin.BatchUserService](c, admin.BatchUserParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminListFile 列出文件 -func AdminListFile(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Files() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListFiles(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Files(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminGetFile 获取文件 func AdminGetFile(c *gin.Context) { - var service admin.FileService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get(c) - // 是否需要重定向 - if res.Code == -301 { - c.Redirect(302, res.Data.(string)) - return - } - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*admin.SingleFileService](c, admin.SingleFileParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteFile 批量删除文件 -func AdminDeleteFile(c *gin.Context) { - var service admin.FileBatchService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Delete(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminUpdateFile(c *gin.Context) { + service := ParametersFromContext[*admin.UpsertFileService](c, admin.UpsertFileParamCtx{}) + res, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListShare 列出分享 -func AdminListShare(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Shares() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetFileUrl(c *gin.Context) { + service := ParametersFromContext[*admin.SingleFileService](c, admin.SingleFileParamCtx{}) + res, err := service.Url(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteShare 批量删除分享 -func AdminDeleteShare(c *gin.Context) { - var service admin.ShareBatchService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Delete(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminBatchDeleteFile(c *gin.Context) { + service := ParametersFromContext[*admin.BatchFileService](c, admin.BatchFileParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{}) } -// AdminListDownload 列出离线下载任务 -func AdminListDownload(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Downloads() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListEntities(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Entities(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteDownload 批量删除任务 -func AdminDeleteDownload(c *gin.Context) { - var service admin.TaskBatchService - if 
err := c.ShouldBindJSON(&service); err == nil { - res := service.Delete(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetEntity(c *gin.Context) { + service := ParametersFromContext[*admin.SingleEntityService](c, admin.SingleEntityParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListTask 列出任务 -func AdminListTask(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Tasks() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetEntityUrl(c *gin.Context) { + service := ParametersFromContext[*admin.SingleEntityService](c, admin.SingleEntityParamCtx{}) + res, err := service.Url(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteTask 批量删除任务 -func AdminDeleteTask(c *gin.Context) { - var service admin.TaskBatchService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.DeleteGeneral(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminBatchDeleteEntity(c *gin.Context) { + service := ParametersFromContext[*admin.BatchEntityService](c, admin.BatchEntityParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } } -// AdminCreateImportTask 新建文件导入任务 -func AdminCreateImportTask(c *gin.Context) { - var service admin.ImportTaskService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListTasks(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Tasks(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListFolders 列出用户或外部文件系统目录 -func AdminListFolders(c *gin.Context) { - var service admin.ListFolderService - if err := c.ShouldBindUri(&service); err == nil { - res := service.List(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetTask(c *gin.Context) { + service := ParametersFromContext[*admin.SingleTaskService](c, admin.SingleTaskParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminListNodes 列出从机节点 -func AdminListNodes(c *gin.Context) { - var service admin.AdminListService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Nodes() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminBatchDeleteTask(c *gin.Context) { + service := ParametersFromContext[*admin.BatchTaskService](c, admin.BatchTaskParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{}) } -// AdminAddNode 新建节点 -func AdminAddNode(c *gin.Context) { - var service admin.AddNodeService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Add() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminListShares(c *gin.Context) { + service := ParametersFromContext[*admin.AdminListService](c, admin.AdminListServiceParamsCtx{}) + res, err := service.Shares(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } 
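Note: every handler added in this refactor follows the same shape: request parameters are bound ahead of time by middleware and retrieved with ParametersFromContext, and errors are wrapped with serializer.Err(c, err) before being returned as JSON. Below is a minimal sketch of how one of these admin routes is presumably registered, assuming the generic FromJSON helper introduced later in this patch (routers/controllers/main.go); the route group, path, and HTTP verb are illustrative and not taken from this change.

    package routers

    import (
        "github.com/cloudreve/Cloudreve/v4/routers/controllers"
        "github.com/cloudreve/Cloudreve/v4/service/admin"
        "github.com/gin-gonic/gin"
    )

    // registerAdminPolicyRoutes is a hypothetical helper; only the binding
    // pattern it demonstrates comes from this patch.
    func registerAdminPolicyRoutes(r *gin.RouterGroup) {
        // FromJSON binds the request body into admin.CreateStoragePolicyService
        // and stores a *admin.CreateStoragePolicyService in the request context
        // under the admin.CreateStoragePolicyParamCtx{} key; AdminCreatePolicy
        // then fetches it with ParametersFromContext using the same key.
        r.PUT("policy",
            controllers.FromJSON[admin.CreateStoragePolicyService](admin.CreateStoragePolicyParamCtx{}),
            controllers.AdminCreatePolicy,
        )
    }

The same pattern applies with FromUri or FromQuery for handlers whose parameters come from the path or query string.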
-// AdminToggleNode 启用/暂停节点 -func AdminToggleNode(c *gin.Context) { - var service admin.ToggleNodeService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Toggle() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminGetShare(c *gin.Context) { + service := ParametersFromContext[*admin.SingleShareService](c, admin.SingleShareParamCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } -// AdminDeleteGroup 删除用户组 -func AdminDeleteNode(c *gin.Context) { - var service admin.NodeService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminBatchDeleteShare(c *gin.Context) { + service := ParametersFromContext[*admin.BatchShareService](c, admin.BatchShareParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } } -// AdminGetNode 获取节点详情 -func AdminGetNode(c *gin.Context) { - var service admin.NodeService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +func AdminCalibrateStorage(c *gin.Context) { + service := ParametersFromContext[*admin.SingleUserService](c, admin.SingleUserParamCtx{}) + res, err := service.CalibrateStorage(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } + c.JSON(200, serializer.Response{Data: res}) } diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go deleted file mode 100644 index 38871795..00000000 --- a/routers/controllers/aria2.go +++ /dev/null @@ -1,97 +0,0 @@ -package controllers - -import ( - "context" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/service/aria2" - "github.com/cloudreve/Cloudreve/v3/service/explorer" - "github.com/gin-gonic/gin" -) - -// AddAria2URL 添加离线下载URL -func AddAria2URL(c *gin.Context) { - var addService aria2.BatchAddURLService - if err := c.ShouldBindJSON(&addService); err == nil { - res := addService.Add(c, common.URLTask) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// SelectAria2File 选择多文件离线下载中要下载的文件 -func SelectAria2File(c *gin.Context) { - var selectService aria2.SelectFileService - if err := c.ShouldBindJSON(&selectService); err == nil { - res := selectService.Select(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// AddAria2Torrent 添加离线下载种子 -func AddAria2Torrent(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - // 获取种子内容的下载地址 - res := service.CreateDownloadSession(ctx, c) - if res.Code != 0 { - c.JSON(200, res) - return - } - - // 创建下载任务 - var addService aria2.AddURLService - addService.URL = res.Data.(string) - - if err := c.ShouldBindJSON(&addService); err == nil { - addService.URL = res.Data.(string) - res := addService.Add(c, nil, common.URLTask) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } - - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// CancelAria2Download 取消或删除aria2离线下载任务 -func CancelAria2Download(c *gin.Context) { - var selectService aria2.DownloadTaskService - if err := c.ShouldBindUri(&selectService); err == nil { - res := selectService.Delete(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// ListDownloading 
获取正在下载中的任务 -func ListDownloading(c *gin.Context) { - var service aria2.DownloadListService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Downloading(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// ListFinished 获取已完成的任务 -func ListFinished(c *gin.Context) { - var service aria2.DownloadListService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Finished(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} diff --git a/routers/controllers/callback.go b/routers/controllers/callback.go index fec5d07f..1b1d0373 100644 --- a/routers/controllers/callback.go +++ b/routers/controllers/callback.go @@ -1,140 +1,126 @@ package controllers import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "path" - "strconv" - - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/service/callback" + "fmt" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/upyun" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/cloudreve/Cloudreve/v4/service/callback" "github.com/gin-gonic/gin" + "github.com/qiniu/go-sdk/v7/auth/qbox" ) -// RemoteCallback 远程上传回调 -func RemoteCallback(c *gin.Context) { - var callbackBody callback.RemoteUploadCallbackService - if err := c.ShouldBindJSON(&callbackBody); err == nil { - res := callback.ProcessCallback(callbackBody, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// RemoteCallback process callback request to complete upload +func ProcessCallback(failedStatusCode int, generalResp bool) gin.HandlerFunc { + return func(c *gin.Context) { + err := callback.ProcessCallback(c) + if err != nil { + if generalResp { + c.JSON(failedStatusCode, serializer.GeneralUploadCallbackFailed{Error: err.Error()}) + } else { + c.JSON(failedStatusCode, serializer.Err(c, err)) + } + return + } + + c.JSON(200, serializer.Response{}) } } -// QiniuCallback 七牛上传回调 -func QiniuCallback(c *gin.Context) { - var callbackBody callback.UploadCallbackService - if err := c.ShouldBindJSON(&callbackBody); err == nil { - res := callback.ProcessCallback(callbackBody, c) - if res.Code != 0 { - c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: res.Msg}) - } else { - c.JSON(200, res) - } - } else { - c.JSON(401, ErrorResponse(err)) +// QiniuCallbackAuth 七牛回调签名验证 +func QiniuCallbackValidate(c *gin.Context) { + session := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession) + + // 验证回调是否来自qiniu + mac := qbox.NewMac(session.Policy.AccessKey, session.Policy.SecretKey) + ok, err := mac.VerifyCallback(c.Request) + if err != nil { + util.Log().Debug("Failed to verify callback request: %s", err) + c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."}) + c.Abort() + return } -} -// OSSCallback 阿里云OSS上传回调 -func OSSCallback(c *gin.Context) { - var callbackBody callback.UploadCallbackService - if err := c.ShouldBindJSON(&callbackBody); err == nil { - if callbackBody.PicInfo == "," { - callbackBody.PicInfo = "" - } - res := callback.ProcessCallback(callbackBody, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + if !ok { + c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Invalid signature."}) 
+ c.Abort() + return } + + c.Next() } -// UpyunCallback 又拍云上传回调 -func UpyunCallback(c *gin.Context) { - var callbackBody callback.UpyunCallbackService - if err := c.ShouldBind(&callbackBody); err == nil { - if callbackBody.Code != 200 { - util.Log().Debug( - "Upload callback returned error code:%d, message: %s", - callbackBody.Code, - callbackBody.Message, - ) +// OSSCallbackValidate 阿里云OSS上传回调 +func OSSCallbackValidate(c *gin.Context) { + var callbackBody callback.UploadCallbackService + if err := c.ShouldBindJSON(&callbackBody); err == nil { + uploadSession := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession) + if uploadSession.Props.Size != callbackBody.Size { + l := logging.FromContext(c) + l.Error("Callback validate failed: size mismatch, expected: %d, actual:%d", uploadSession.Props.Size, callbackBody.Size) + c.JSON(401, + serializer.GeneralUploadCallbackFailed{ + Error: fmt.Sprintf("size mismatch"), + }) + c.Abort() return } - res := callback.ProcessCallback(callbackBody, c) - c.JSON(200, res) + + c.Next() } else { - c.JSON(200, ErrorResponse(err)) + c.JSON(401, ErrorResponse(err)) + c.Abort() } } -// OneDriveCallback OneDrive上传完成客户端回调 -func OneDriveCallback(c *gin.Context) { - var callbackBody callback.OneDriveCallback - if err := c.ShouldBindJSON(&callbackBody); err == nil { - res := callbackBody.PreProcess(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// UpyunCallbackAuth 又拍云回调签名验证 +func UpyunCallbackAuth(c *gin.Context) { + uploadSession := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession) + l := logging.FromContext(c) + if err := upyun.ValidateCallback(c, uploadSession); err != nil { + l.Error("Failed to verify callback request: %s", err) + + c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."}) } + + c.Next() } // OneDriveOAuth OneDrive 授权回调 func OneDriveOAuth(c *gin.Context) { - var callbackBody callback.OauthService - if err := c.ShouldBindQuery(&callbackBody); err == nil { - res := callbackBody.OdAuth(c) - redirect := model.GetSiteURL() - redirect.Path = path.Join(redirect.Path, "/admin/policy") - queries := redirect.Query() - queries.Add("code", strconv.Itoa(res.Code)) - queries.Add("msg", res.Msg) - queries.Add("err", res.Error) - redirect.RawQuery = queries.Encode() - c.Redirect(303, redirect.String()) - } else { - c.JSON(200, ErrorResponse(err)) - } + //var callbackBody callback.OauthService + //if err := c.ShouldBindQuery(&callbackBody); err == nil { + // res := callbackBody.OdAuth(c) + // redirect := model.GetSiteURL() + // redirect.Path = path.Join(redirect.Path, "/admin/policy") + // queries := redirect.Query() + // queries.Add("code", strconv.Itoa(res.Code)) + // queries.Add("msg", res.Msg) + // queries.Add("err", res.Error) + // redirect.RawQuery = queries.Encode() + // c.Redirect(303, redirect.String()) + //} else { + // c.JSON(200, ErrorResponse(err)) + //} } // GoogleDriveOAuth Google Drive 授权回调 func GoogleDriveOAuth(c *gin.Context) { - var callbackBody callback.OauthService - if err := c.ShouldBindQuery(&callbackBody); err == nil { - res := callbackBody.GDriveAuth(c) - redirect := model.GetSiteURL() - redirect.Path = path.Join(redirect.Path, "/admin/policy") - queries := redirect.Query() - queries.Add("code", strconv.Itoa(res.Code)) - queries.Add("msg", res.Msg) - queries.Add("err", res.Error) - redirect.RawQuery = queries.Encode() - c.Redirect(303, redirect.String()) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// COSCallback COS上传完成客户端回调 -func COSCallback(c *gin.Context) { 
- var callbackBody callback.COSCallback - if err := c.ShouldBindQuery(&callbackBody); err == nil { - res := callbackBody.PreProcess(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// S3Callback S3上传完成客户端回调 -func S3Callback(c *gin.Context) { - var callbackBody callback.S3Callback - if err := c.ShouldBindQuery(&callbackBody); err == nil { - res := callbackBody.PreProcess(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + //var callbackBody callback.OauthService + //if err := c.ShouldBindQuery(&callbackBody); err == nil { + // res := callbackBody.GDriveAuth(c) + // redirect := model.GetSiteURL() + // redirect.Path = path.Join(redirect.Path, "/admin/policy") + // queries := redirect.Query() + // queries.Add("code", strconv.Itoa(res.Code)) + // queries.Add("msg", res.Msg) + // queries.Add("err", res.Error) + // redirect.RawQuery = queries.Encode() + // c.Redirect(303, redirect.String()) + //} else { + // c.JSON(200, ErrorResponse(err)) + //} } diff --git a/routers/controllers/directory.go b/routers/controllers/directory.go index a2e06851..10ebf2b8 100644 --- a/routers/controllers/directory.go +++ b/routers/controllers/directory.go @@ -1,28 +1,27 @@ package controllers import ( - "github.com/cloudreve/Cloudreve/v3/service/explorer" + "errors" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/explorer" "github.com/gin-gonic/gin" ) -// CreateDirectory 创建目录 -func CreateDirectory(c *gin.Context) { - var service explorer.DirectoryService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.CreateDirectory(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - // ListDirectory 列出目录下内容 func ListDirectory(c *gin.Context) { - var service explorer.DirectoryService - if err := c.ShouldBindUri(&service); err == nil { - res := service.ListDirectory(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*explorer.ListFileService](c, explorer.ListFileParameterCtx{}) + resp, err := service.List(c) + if err != nil { + if errors.Is(err, explorer.ErrSSETakeOver) { + return + } + + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{ + Data: resp, + }) } diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 0e7c2062..8b70c5fa 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -1,421 +1,373 @@ package controllers import ( - "context" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "net/http" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/service/explorer" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/explorer" "github.com/gin-gonic/gin" ) func DownloadArchive(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ArchiveService - if err := c.ShouldBindUri(&service); err == nil { - service.DownloadArchived(ctx, c) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*explorer.ArchiveService](c, explorer.ArchiveParamCtx{}) + err := service.DownloadArchived(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } } -func Archive(c *gin.Context) { - // 创建上下文 - ctx, 
cancel := context.WithCancel(context.Background()) - defer cancel() +// CreateArchive 创建文件压缩任务 +func CreateArchive(c *gin.Context) { + service := ParametersFromContext[*explorer.ArchiveWorkflowService](c, explorer.CreateArchiveParamCtx{}) + resp, err := service.CreateCompressTask(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } - var service explorer.ItemIDService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Archive(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } -// Compress 创建文件压缩任务 -func Compress(c *gin.Context) { - var service explorer.ItemCompressService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.CreateCompressTask(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// CreateRemoteDownload creates remote download task +func CreateRemoteDownload(c *gin.Context) { + service := ParametersFromContext[*explorer.DownloadWorkflowService](c, explorer.CreateDownloadParamCtx{}) + resp, err := service.CreateDownloadTask(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// Decompress 创建文件解压缩任务 -func Decompress(c *gin.Context) { - var service explorer.ItemDecompressService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.CreateDecompressTask(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } -// AnonymousGetContent 匿名获取文件资源 -func AnonymousGetContent(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.FileAnonymousGetService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Download(ctx, c) - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// ExtractArchive creates extract archive task +func ExtractArchive(c *gin.Context) { + service := ParametersFromContext[*explorer.ArchiveWorkflowService](c, explorer.CreateArchiveParamCtx{}) + resp, err := service.CreateExtractTask(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// AnonymousPermLink Deprecated 文件签名后的永久链接 -func AnonymousPermLinkDeprecated(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.FileAnonymousGetService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Source(ctx, c) - // 是否需要重定向 - if res.Code == -302 { - c.Redirect(302, res.Data.(string)) - return - } - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } // AnonymousPermLink 文件中转后的永久直链接 func AnonymousPermLink(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sourceLinkRaw, ok := c.Get("source_link") - if !ok { - c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) + name := c.Param("name") + if err := explorer.RedirectDirectLink(c, name); err != nil { + c.JSON(404, serializer.Err(c, err)) + c.Abort() return } +} - sourceLink := sourceLinkRaw.(*model.SourceLink) - - service := &explorer.FileAnonymousGetService{ - ID: sourceLink.FileID, - Name: sourceLink.File.Name, +// GetSource 获取文件的外链地址 +func GetSource(c *gin.Context) { + service := 
ParametersFromContext[*explorer.GetDirectLinkService](c, explorer.GetDirectLinkParamCtx{}) + res, err := service.Get(c) + if err != nil && len(res) == 0 { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } - res := service.Source(ctx, c) - // 是否需要重定向 - if res.Code == -302 { - c.Redirect(302, res.Data.(string)) + if err != nil { + // Not fully completed + errResp := serializer.Err(c, err) + errResp.Data = res + c.JSON(200, errResp) return } - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) + c.JSON(200, serializer.Response{Data: res}) +} + +// Thumb 获取文件缩略图 +func Thumb(c *gin.Context) { + service := ParametersFromContext[*explorer.FileThumbService](c, explorer.FileThumbParameterCtx{}) + res, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + c.JSON(200, serializer.Response{Data: res}) } -func GetSource(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// FileURL get temporary file url for preview or download +func FileURL(c *gin.Context) { + service := ParametersFromContext[*explorer.FileURLService](c, explorer.FileURLParameterCtx{}) + resp, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } - var service explorer.ItemIDService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Sources(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } -// Thumb 获取文件缩略图 -func Thumb(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - fs, err := filesystem.NewFileSystemFromContext(c) +// ServeEntity download entity content +func ServeEntity(c *gin.Context) { + service := ParametersFromContext[*explorer.EntityDownloadService](c, explorer.EntityDownloadParameterCtx{}) + err := service.Serve(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - defer fs.Recycle() +} - // 获取文件ID - fileID, ok := c.Get("object_id") - if !ok { - c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", err)) +// CreateViewerSession creates a viewer session +func CreateViewerSession(c *gin.Context) { + service := ParametersFromContext[*explorer.CreateViewerSessionService](c, explorer.CreateViewerSessionParamCtx{}) + resp, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - // 获取缩略图 - resp, err := fs.GetThumb(ctx, fileID.(uint)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) + } +} + +// PutContent 更新文件内容 +func PutContent(c *gin.Context) { + service := ParametersFromContext[*explorer.FileUpdateService](c, explorer.FileUpdateParameterCtx{}) + res, err := service.PutContent(c, nil) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeNotSet, "Failed to get thumbnail", err)) + c.JSON(200, serializer.Err(c, err)) + request.BlackHole(c.Request.Body) + c.Abort() return } - if resp.Redirect { - c.Header("Cache-Control", fmt.Sprintf("max-age=%d", resp.MaxAge)) - c.Redirect(http.StatusMovedPermanently, resp.URL) + c.JSON(200, serializer.Response{Data: res}) +} + +// FileUpload 本地策略文件上传 +func FileUpload(c *gin.Context) { + service := ParametersFromContext[*explorer.UploadService](c, explorer.UploadParameterCtx{}) + err := service.LocalUpload(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + 
request.BlackHole(c.Request.Body) + c.Abort() return } - defer resp.Content.Close() - http.ServeContent(c.Writer, c.Request, "thumb."+model.GetSettingByNameWithDefault("thumb_encode_method", "jpg"), fs.FileTarget[0].UpdatedAt, resp.Content) - + c.JSON(200, serializer.Response{}) } -// Preview 预览文件 -func Preview(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - res := service.PreviewContent(ctx, c, false) - // 是否需要重定向 - if res.Code == -301 { - c.Redirect(302, res.Data.(string)) - return - } - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// DeleteUploadSession 删除上传会话 +func DeleteUploadSession(c *gin.Context) { + service := ParametersFromContext[*explorer.DeleteUploadSessionService](c, explorer.DeleteUploadSessionParameterCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } -// PreviewText 预览文本文件 -func PreviewText(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - res := service.PreviewContent(ctx, c, true) - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// CreateUploadSession 创建上传会话 +func CreateUploadSession(c *gin.Context) { + service := ParametersFromContext[*explorer.CreateUploadSessionService](c, explorer.CreateUploadSessionParameterCtx{}) + resp, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// GetDocPreview 获取DOC文件预览地址 -func GetDocPreview(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + c.JSON(200, serializer.Response{ + Data: resp, + }) +} - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - res := service.CreateDocPreviewSession(ctx, c, true) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// CreateFile 创建空白文件 +func CreateFile(c *gin.Context) { + service := ParametersFromContext[*explorer.CreateFileService](c, explorer.CreateFileParameterCtx{}) + resp, err := service.Create(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// CreateDownloadSession 创建文件下载会话 -func CreateDownloadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + c.JSON(200, serializer.Response{ + Data: resp, + }) +} - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - res := service.CreateDownloadSession(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// RenameFile Renames a file. 
+func RenameFile(c *gin.Context) { + service := ParametersFromContext[*explorer.RenameFileService](c, explorer.RenameFileParameterCtx{}) + resp, err := service.Rename(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{ + Data: resp, + }) } -// Download 文件下载 -func Download(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.DownloadService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Download(ctx, c) - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// MoveFile Moves or Copy files. +func MoveFile(c *gin.Context) { + service := ParametersFromContext[*explorer.MoveFileService](c, explorer.MoveFileParameterCtx{}) + if err := service.Move(c); err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// PutContent 更新文件内容 -func PutContent(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + c.JSON(200, serializer.Response{}) +} - var service explorer.FileIDService - if err := c.ShouldBindUri(&service); err == nil { - res := service.PutContent(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// Delete 删除文件或目录 +func Delete(c *gin.Context) { + service := ParametersFromContext[*explorer.DeleteFileService](c, explorer.DeleteFileParameterCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// FileUpload 本地策略文件上传 -func FileUpload(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.UploadService - if err := c.ShouldBindUri(&service); err == nil { - res := service.LocalUpload(ctx, c) - c.JSON(200, res) - request.BlackHole(c.Request.Body) - } else { - c.JSON(200, ErrorResponse(err)) - } - - //fileData := fsctx.FileStream{ - // MIMEType: c.Request.Header.Get("Content-Type"), - // File: c.Request.Body, - // Size: fileSize, - // Name: fileName, - // VirtualPath: filePath, - // Mode: fsctx.Create, - //} - // - //// 创建文件系统 - //fs, err := filesystem.NewFileSystemFromContext(c) - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) - // return - //} - // - //// 非可用策略时拒绝上传 - //if !fs.Policy.IsTransitUpload(fileSize) { - // request.BlackHole(c.Request.Body) - // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, "当前存储策略无法使用", nil)) - // return - //} - // - //// 给文件系统分配钩子 - //fs.Use("BeforeUpload", filesystem.HookValidateFile) - //fs.Use("BeforeUpload", filesystem.HookValidateCapacity) - //fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) - //fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) - //fs.Use("AfterUpload", filesystem.GenericAfterUpload) - //fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - //fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) - //fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity) - // - //// 执行上传 - //ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) - //uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) - //err = fs.Upload(uploadCtx, &fileData) - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) - // return - //} - // - //c.JSON(200, serializer.Response{ - // Code: 0, - //}) + c.JSON(200, serializer.Response{}) } -// DeleteUploadSession 删除上传会话 -func 
DeleteUploadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// Restore restore file or directory +func Restore(c *gin.Context) { + service := ParametersFromContext[*explorer.DeleteFileService](c, explorer.DeleteFileParameterCtx{}) + err := service.Restore(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + c.JSON(200, serializer.Response{}) +} - var service explorer.UploadSessionService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// Unlock unlocks files by given tokens +func Unlock(c *gin.Context) { + service := ParametersFromContext[*explorer.UnlockFileService](c, explorer.UnlockFileParameterCtx{}) + err := service.Unlock(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } -// DeleteAllUploadSession 删除全部上传会话 -func DeleteAllUploadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// Pin pins files by given uri +func Pin(c *gin.Context) { + service := ParametersFromContext[*explorer.PinFileService](c, explorer.PinFileParameterCtx{}) + err := service.PinFile(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } - res := explorer.DeleteAllUploadSession(ctx, c) - c.JSON(200, res) + c.JSON(200, serializer.Response{}) } -// GetUploadSession 创建上传会话 -func GetUploadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// Unpin unpins files by given uri +func Unpin(c *gin.Context) { + service := ParametersFromContext[*explorer.PinFileService](c, explorer.PinFileParameterCtx{}) + err := service.UnpinFile(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return + } + + c.JSON(200, serializer.Response{}) +} - var service explorer.CreateUploadSessionService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// PatchMetadata patch metadata +func PatchMetadata(c *gin.Context) { + service := ParametersFromContext[*explorer.PatchMetadataService](c, explorer.PatchMetadataParameterCtx{}) + err := service.Patch(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } -// SearchFile 搜索文件 -func SearchFile(c *gin.Context) { - var service explorer.ItemSearchService - if err := c.ShouldBindUri(&service); err != nil { - c.JSON(200, ErrorResponse(err)) +// GetFileInfo gets file info +func GetFileInfo(c *gin.Context) { + service := ParametersFromContext[*explorer.GetFileInfoService](c, explorer.GetFileInfoParameterCtx{}) + resp, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - if err := c.ShouldBindQuery(&service); err != nil { - c.JSON(200, ErrorResponse(err)) + c.JSON(200, serializer.Response{ + Data: resp, + }) +} + +// SetCurrentVersion sets current version +func SetCurrentVersion(c *gin.Context) { + service := ParametersFromContext[*explorer.SetCurrentVersionService](c, explorer.SetCurrentVersionParamCtx{}) + err := service.Set(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - res := service.Search(c) - c.JSON(200, res) + c.JSON(200, serializer.Response{}) } -// CreateFile 创建空白文件 -func CreateFile(c *gin.Context) { - var service 
explorer.SingleFileService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// DeleteVersion deletes a version +func DeleteVersion(c *gin.Context) { + service := ParametersFromContext[*explorer.DeleteVersionService](c, explorer.DeleteVersionParamCtx{}) + err := service.Delete(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } diff --git a/routers/controllers/main.go b/routers/controllers/main.go index aaac8b22..aa0335a7 100644 --- a/routers/controllers/main.go +++ b/routers/controllers/main.go @@ -1,10 +1,9 @@ package controllers import ( + "context" "encoding/json" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" ) @@ -44,7 +43,7 @@ func ErrorResponse(err error) serializer.Response { // 处理 Validator 产生的错误 if ve, ok := err.(validator.ValidationErrors); ok { for _, e := range ve { - return serializer.ParamErr( + return serializer.ParamErrDeprecated( ParamErrorMsg(e.Field(), e.Tag()), err, ) @@ -52,18 +51,55 @@ func ErrorResponse(err error) serializer.Response { } if _, ok := err.(*json.UnmarshalTypeError); ok { - return serializer.ParamErr("JSON marshall error", err) + return serializer.ParamErrDeprecated("JSON marshall error", err) } - return serializer.ParamErr("Parameter error", err) + return serializer.ParamErrDeprecated("Parameter error", err) +} + +// FromJSON Parse and validate JSON from request body +func FromJSON[T any](ctxKey any) gin.HandlerFunc { + return func(c *gin.Context) { + var service T + if err := c.ShouldBindJSON(&service); err == nil { + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxKey, &service)) + c.Next() + } else { + c.JSON(200, ErrorResponse(err)) + c.Abort() + } + } } -// CurrentUser 获取当前用户 -func CurrentUser(c *gin.Context) *model.User { - if user, _ := c.Get("user"); user != nil { - if u, ok := user.(*model.User); ok { - return u +// FromQuery Parse and validate form from request query +func FromQuery[T any](ctxKey any) gin.HandlerFunc { + return func(c *gin.Context) { + var service T + if err := c.ShouldBindQuery(&service); err == nil { + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxKey, &service)) + c.Next() + } else { + c.JSON(200, ErrorResponse(err)) + c.Abort() } } - return nil +} + +// FromUri Parse and validate form from request uri +func FromUri[T any](ctxKey any) gin.HandlerFunc { + return func(c *gin.Context) { + var service T + if err := c.ShouldBindUri(&service); err == nil { + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxKey, &service)) + c.Next() + } else { + c.JSON(200, ErrorResponse(err)) + c.Abort() + } + } +} + +// ParametersFromContext retrieves request parameters from context +func ParametersFromContext[T any](c *gin.Context, ctxKey any) T { + return c.Request.Context().Value(ctxKey).(T) } diff --git a/routers/controllers/objects.go b/routers/controllers/objects.go deleted file mode 100644 index a6095b55..00000000 --- a/routers/controllers/objects.go +++ /dev/null @@ -1,84 +0,0 @@ -package controllers - -import ( - "context" - - "github.com/cloudreve/Cloudreve/v3/service/explorer" - "github.com/gin-gonic/gin" -) - -// Delete 删除文件或目录 -func Delete(c *gin.Context) { - // 创建上下文 - ctx, cancel := 
context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ItemIDService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Delete(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// Move 移动文件或目录 -func Move(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ItemMoveService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Move(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// Copy 复制文件或目录 -func Copy(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ItemMoveService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Copy(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// Rename 重命名文件或目录 -func Rename(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ItemRenameService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Rename(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// Rename 重命名文件或目录 -func GetProperty(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.ItemPropertyService - service.ID = c.Param("id") - if err := c.ShouldBindQuery(&service); err == nil { - res := service.GetProperty(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} diff --git a/routers/controllers/share.go b/routers/controllers/share.go index 8795c1eb..3aa549ca 100644 --- a/routers/controllers/share.go +++ b/routers/controllers/share.go @@ -1,237 +1,78 @@ package controllers import ( - "context" - "path" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/service/share" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/share" "github.com/gin-gonic/gin" + "net/http" ) // CreateShare 创建分享 func CreateShare(c *gin.Context) { - var service share.ShareCreateService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// GetShare 查看分享 -func GetShare(c *gin.Context) { - var service share.ShareGetService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Get(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// ListShare 列出分享 -func ListShare(c *gin.Context) { - var service share.ShareListService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.List(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// SearchShare 搜索分享 -func SearchShare(c *gin.Context) { - var service share.ShareListService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Search(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// UpdateShare 更新分享属性 -func UpdateShare(c *gin.Context) { - var service share.ShareUpdateService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Update(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// DeleteShare 删除分享 -func 
DeleteShare(c *gin.Context) { - var service share.Service - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*share.ShareCreateService](c, share.ShareCreateParamCtx{}) + uri, err := service.Upsert(c, 0) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } -} -// GetShareDownload 创建分享下载会话 -func GetShareDownload(c *gin.Context) { - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - res := service.CreateDownloadSession(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{Data: uri}) } -// PreviewShare 预览分享文件内容 -func PreviewShare(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - res := service.PreviewContent(ctx, c, false) - // 是否需要重定向 - if res.Code == -301 { - c.Redirect(302, res.Data.(string)) - return - } - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// EditShare 编辑分享 +func EditShare(c *gin.Context) { + service := ParametersFromContext[*share.ShareCreateService](c, share.ShareCreateParamCtx{}) + uri, err := service.Upsert(c, hashid.FromContext(c)) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } -} -// PreviewShareText 预览文本文件 -func PreviewShareText(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - res := service.PreviewContent(ctx, c, true) - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{Data: uri}) } -// PreviewShareReadme 预览文本自述文件 -func PreviewShareReadme(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - // 自述文件名限制 - allowFileName := []string{"readme.txt", "readme.md"} - fileName := strings.ToLower(path.Base(service.Path)) - if !util.ContainsString(allowFileName, fileName) { - c.JSON(200, serializer.ParamErr("Not a README file", nil)) - } - - // 必须是目录分享 - if shareCtx, ok := c.Get("share"); ok { - if !shareCtx.(*model.Share).IsDir { - c.JSON(200, serializer.ParamErr("This share has no README file", nil)) - } - } - - res := service.PreviewContent(ctx, c, true) - // 是否有错误发生 - if res.Code != 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) +// GetShare 查看分享 +func GetShare(c *gin.Context) { + service := ParametersFromContext[*share.ShareInfoService](c, share.ShareInfoParamCtx{}) + info, err := service.Get(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } -} -// GetShareDocPreview 创建分享Office文档预览地址 -func GetShareDocPreview(c *gin.Context) { - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - res := service.CreateDocPreviewSession(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{Data: info}) } -// ListSharedFolder 列出分享的目录下的对象 -func ListSharedFolder(c *gin.Context) { - var service share.Service - if err := c.ShouldBindUri(&service); err == nil { - res := service.List(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// 
SearchSharedFolder 搜索分享的目录下的对象 -func SearchSharedFolder(c *gin.Context) { - var service share.SearchService - if err := c.ShouldBindUri(&service); err != nil { - c.JSON(200, ErrorResponse(err)) +// ListShare 列出分享 +func ListShare(c *gin.Context) { + service := ParametersFromContext[*share.ListShareService](c, share.ListShareParamCtx{}) + resp, err := service.List(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - if err := c.ShouldBindQuery(&service); err != nil { - c.JSON(200, ErrorResponse(err)) - return + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } - - res := service.Search(c) - c.JSON(200, res) } -// ArchiveShare 打包要下载的分享 -func ArchiveShare(c *gin.Context) { - var service share.ArchiveService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Archive(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// DeleteShare 删除分享 +func DeleteShare(c *gin.Context) { + err := share.DeleteShare(c, hashid.FromContext(c)) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + return } -} -// ShareThumb 获取分享目录下文件的缩略图 -func ShareThumb(c *gin.Context) { - var service share.Service - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Thumb(c) - if res.Code >= 0 { - c.JSON(200, res) - } - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{}) } -// GetUserShare 查看给定用户的分享 -func GetUserShare(c *gin.Context) { - var service share.ShareUserGetService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.Get(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } +func ShareRedirect(c *gin.Context) { + service := ParametersFromContext[*share.ShortLinkRedirectService](c, share.ShortLinkRedirectParamCtx{}) + c.Redirect(http.StatusFound, service.RedirectTo(c)) } diff --git a/routers/controllers/site.go b/routers/controllers/site.go index c4a35080..38b803df 100644 --- a/routers/controllers/site.go +++ b/routers/controllers/site.go @@ -1,55 +1,33 @@ package controllers import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/basic" "github.com/gin-gonic/gin" - "github.com/mojocn/base64Captcha" ) // SiteConfig 获取站点全局配置 func SiteConfig(c *gin.Context) { - siteConfig := model.GetSettingByNames( - "siteName", - "login_captcha", - "reg_captcha", - "email_active", - "forget_captcha", - "email_active", - "themes", - "defaultTheme", - "home_view_method", - "share_view_method", - "authn_enabled", - "captcha_ReCaptchaKey", - "captcha_type", - "captcha_TCaptcha_CaptchaAppId", - "register_enabled", - "show_app_promotion", - ) + service := ParametersFromContext[*basic.GetSettingService](c, basic.GetSettingParamCtx{}) - var wopiExts []string - if wopi.Default != nil { - wopiExts = wopi.Default.AvailableExts() - } - - // 如果已登录,则同时返回用户信息和标签 - user, _ := c.Get("user") - if user, ok := user.(*model.User); ok { - c.JSON(200, serializer.BuildSiteConfig(siteConfig, user, wopiExts)) + resp, err := service.GetSiteConfig(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - c.JSON(200, serializer.BuildSiteConfig(siteConfig, nil, 
wopiExts)) + c.JSON(200, serializer.Response{ + Data: resp, + }) } // Ping 状态检查页面 func Ping(c *gin.Context) { - version := conf.BackendVersion - if conf.IsPro == "true" { + version := constants.BackendVersion + if constants.IsProBool { version += "-pro" } @@ -61,81 +39,41 @@ func Ping(c *gin.Context) { // Captcha 获取验证码 func Captcha(c *gin.Context) { - options := model.GetSettingByNames( - "captcha_IsShowHollowLine", - "captcha_IsShowNoiseDot", - "captcha_IsShowNoiseText", - "captcha_IsShowSlimeLine", - "captcha_IsShowSineLine", - ) - // 验证码配置 - var configD = base64Captcha.ConfigCharacter{ - Height: model.GetIntSetting("captcha_height", 60), - Width: model.GetIntSetting("captcha_width", 240), - //const CaptchaModeNumber:数字,CaptchaModeAlphabet:字母,CaptchaModeArithmetic:算术,CaptchaModeNumberAlphabet:数字字母混合. - Mode: model.GetIntSetting("captcha_mode", 3), - ComplexOfNoiseText: model.GetIntSetting("captcha_ComplexOfNoiseText", 0), - ComplexOfNoiseDot: model.GetIntSetting("captcha_ComplexOfNoiseDot", 0), - IsShowHollowLine: model.IsTrueVal(options["captcha_IsShowHollowLine"]), - IsShowNoiseDot: model.IsTrueVal(options["captcha_IsShowNoiseDot"]), - IsShowNoiseText: model.IsTrueVal(options["captcha_IsShowNoiseText"]), - IsShowSlimeLine: model.IsTrueVal(options["captcha_IsShowSlimeLine"]), - IsShowSineLine: model.IsTrueVal(options["captcha_IsShowSineLine"]), - CaptchaLen: model.GetIntSetting("captcha_CaptchaLen", 6), - } - - // 生成验证码 - idKeyD, capD := base64Captcha.GenerateCaptcha("", configD) - // 将验证码UID存入Session以便后续验证 - util.SetSession(c, map[string]interface{}{ - "captchaID": idKeyD, - }) - - // 将验证码图像编码为Base64 - base64stringD := base64Captcha.CaptchaWriteToBase64Encoding(capD) - c.JSON(200, serializer.Response{ Code: 0, - Data: base64stringD, + Data: basic.GetCaptchaImage(c), }) } // Manifest 获取manifest.json func Manifest(c *gin.Context) { - options := model.GetSettingByNames( - "siteName", - "siteTitle", - "pwa_small_icon", - "pwa_medium_icon", - "pwa_large_icon", - "pwa_display", - "pwa_theme_color", - "pwa_background_color", - ) - + settingClient := dependency.FromContext(c).SettingProvider() + siteOpts := settingClient.SiteBasic(c) + pwaOpts := settingClient.PWA(c) + c.Header("Cache-Control", "public, no-cache") c.JSON(200, map[string]interface{}{ - "short_name": options["siteName"], - "name": options["siteTitle"], + "short_name": siteOpts.Name, + "name": siteOpts.Name, "icons": []map[string]string{ { - "src": options["pwa_small_icon"], + "src": pwaOpts.SmallIcon, "sizes": "64x64 32x32 24x24 16x16", "type": "image/x-icon", }, { - "src": options["pwa_medium_icon"], + "src": pwaOpts.MediumIcon, "type": "image/png", "sizes": "192x192", }, { - "src": options["pwa_large_icon"], + "src": pwaOpts.LargeIcon, "type": "image/png", "sizes": "512x512", }, }, "start_url": ".", - "display": options["pwa_display"], - "theme_color": options["pwa_theme_color"], - "background_color": options["pwa_background_color"], + "display": pwaOpts.Display, + "theme_color": pwaOpts.ThemeColor, + "background_color": pwaOpts.BackgroundColor, }) } diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 2df36988..d5aba210 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -1,138 +1,118 @@ package controllers import ( - "context" - - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/service/admin" - "github.com/cloudreve/Cloudreve/v3/service/aria2" - 
"github.com/cloudreve/Cloudreve/v3/service/explorer" - "github.com/cloudreve/Cloudreve/v3/service/node" + "fmt" + + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/admin" + "github.com/cloudreve/Cloudreve/v4/service/explorer" + "github.com/cloudreve/Cloudreve/v4/service/node" "github.com/gin-gonic/gin" ) // SlaveUpload 从机文件上传 func SlaveUpload(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.UploadService - if err := c.ShouldBindUri(&service); err == nil { - res := service.SlaveUpload(ctx, c) - c.JSON(200, res) + service := ParametersFromContext[*explorer.UploadService](c, explorer.UploadParameterCtx{}) + err := service.SlaveUpload(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) request.BlackHole(c.Request.Body) - } else { - c.JSON(200, ErrorResponse(err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } // SlaveGetUploadSession 从机创建上传会话 func SlaveGetUploadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.SlaveCreateUploadSessionService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*explorer.SlaveCreateUploadSessionService](c, explorer.SlaveCreateUploadSessionParamCtx{}) + if err := service.Create(c); err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{}) } // SlaveDeleteUploadSession 从机删除上传会话 func SlaveDeleteUploadSession(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.UploadSessionService - if err := c.ShouldBindUri(&service); err == nil { - res := service.SlaveDelete(ctx, c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*explorer.SlaveDeleteUploadSessionService](c, explorer.SlaveDeleteUploadSessionParamCtx{}) + if err := service.Delete(c); err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// SlaveDownload 从机文件下载,此请求返回的HTTP状态码不全为200 -func SlaveDownload(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + c.JSON(200, serializer.Response{}) +} - var service explorer.SlaveDownloadService - if err := c.ShouldBindUri(&service); err == nil { - res := service.ServeFile(ctx, c, true) - if res.Code != 0 { - c.JSON(400, res) - } - } else { - c.JSON(400, ErrorResponse(err)) +// SlaveServeEntity download entity content +func SlaveServeEntity(c *gin.Context) { + service := ParametersFromContext[*explorer.EntityDownloadService](c, explorer.EntityDownloadParameterCtx{}) + err := service.SlaveServe(c) + if err != nil { + c.JSON(400, serializer.Err(c, err)) + c.Abort() + return } } -// SlavePreview 从机文件预览 -func SlavePreview(c *gin.Context) { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var service explorer.SlaveDownloadService - if err := c.ShouldBindUri(&service); err == nil { - res := service.ServeFile(ctx, c, false) - if res.Code != 0 { - c.JSON(200, res) 
-		}
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveMeta retrieves media metadata
+func SlaveMeta(c *gin.Context) {
+	service := ParametersFromContext[*explorer.SlaveMetaService](c, explorer.SlaveMetaParamCtx{})
+	res, err := service.MediaMeta(c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, res))
 }

 // SlaveThumb 从机文件缩略图
 func SlaveThumb(c *gin.Context) {
-	// 创建上下文
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	var service explorer.SlaveFileService
-	if err := c.ShouldBindUri(&service); err == nil {
-		res := service.Thumb(ctx, c)
-		if res.Code != 0 {
-			c.JSON(200, res)
-		}
-	} else {
-		c.JSON(200, ErrorResponse(err))
+	service := ParametersFromContext[*explorer.SlaveThumbService](c, explorer.SlaveThumbParamCtx{})
+	err := service.Thumb(c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
 }

 // SlaveDelete 从机删除
 func SlaveDelete(c *gin.Context) {
-	// 创建上下文
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	var service explorer.SlaveFilesService
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := service.Delete(ctx, c)
-		c.JSON(200, res)
+	service := ParametersFromContext[*explorer.SlaveDeleteFileService](c, explorer.SlaveDeleteFileParamCtx{})
+	if failed, err := service.Delete(c); err != nil {
+		c.JSON(200, serializer.NewResponseWithGobData(c, serializer.Response{
+			Code:  serializer.CodeNotFullySuccess,
+			Data:  failed,
+			Msg:   fmt.Sprintf("Failed to delete %d file(s)", len(failed)),
+			Error: err.Error(),
+		}))
 	} else {
-		c.JSON(200, ErrorResponse(err))
+		c.JSON(200, serializer.Response{Data: ""})
 	}
 }

 // SlavePing 从机测试
 func SlavePing(c *gin.Context) {
-	var service admin.SlavePingService
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := service.Test()
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+	service := ParametersFromContext[*admin.SlavePingService](c, admin.SlavePingParameterCtx{})
+	if err := service.Test(c); err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.Response{})
 }

 // SlaveList 从机列出文件
@@ -146,101 +126,167 @@ func SlaveList(c *gin.Context) {
 	}
 }

-// SlaveHeartbeat 接受主机心跳包
-func SlaveHeartbeat(c *gin.Context) {
-	var service serializer.NodePingReq
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := node.HandleMasterHeartbeat(&service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveDownloadTaskCreate creates a download task on slave
+func SlaveDownloadTaskCreate(c *gin.Context) {
+	service := ParametersFromContext[*slave.CreateSlaveDownload](c, node.CreateSlaveDownloadTaskParamCtx{})
+	d := c.MustGet(downloader.DownloaderCtxKey).(downloader.Downloader)
+	handle, err := d.CreateTask(c, service.Url, service.Options)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, handle))
 }

-// SlaveAria2Create 创建 Aria2 任务
-func SlaveAria2Create(c *gin.Context) {
-	var service serializer.SlaveAria2Call
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := aria2.Add(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveDownloadTaskStatus 查询从机 Aria2 任务状态
+func SlaveDownloadTaskStatus(c *gin.Context) {
+	service := ParametersFromContext[*slave.GetSlaveDownload](c, node.GetSlaveDownloadTaskParamCtx{})
+	d := c.MustGet(downloader.DownloaderCtxKey).(downloader.Downloader)
+	info, err := d.Info(c, service.Handle)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, info))
 }

-// SlaveAria2Status 查询从机 Aria2 任务状态
-func SlaveAria2Status(c *gin.Context) {
-	var service serializer.SlaveAria2Call
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := aria2.SlaveStatus(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveCancelDownloadTask 取消从机离线下载任务
+func SlaveCancelDownloadTask(c *gin.Context) {
+	service := ParametersFromContext[*slave.CancelSlaveDownload](c, node.CancelSlaveDownloadTaskParamCtx{})
+	d := c.MustGet(downloader.DownloaderCtxKey).(downloader.Downloader)
+	err := d.Cancel(c, service.Handle)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.Response{})
 }

-// SlaveCancelAria2Task 取消从机离线下载任务
-func SlaveCancelAria2Task(c *gin.Context) {
-	var service serializer.SlaveAria2Call
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := aria2.SlaveCancel(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveSelectFilesToDownload 从机选取离线下载文件
+func SlaveSelectFilesToDownload(c *gin.Context) {
+	service := ParametersFromContext[*slave.SetSlaveFilesToDownload](c, node.SelectSlaveDownloadFilesParamCtx{})
+	d := c.MustGet(downloader.DownloaderCtxKey).(downloader.Downloader)
+	err := d.SetFilesToDownload(c, service.Handle, service.Args...)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.Response{})
 }

-// SlaveSelectTask 从机选取离线下载文件
-func SlaveSelectTask(c *gin.Context) {
-	var service serializer.SlaveAria2Call
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := aria2.SlaveSelect(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveTestDownloader 从机测试下载器连接
+func SlaveTestDownloader(c *gin.Context) {
+	d := c.MustGet(downloader.DownloaderCtxKey).(downloader.Downloader)
+	res, err := d.Test(c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.Response{Data: res})
 }

-// SlaveCreateTransferTask 从机创建中转任务
-func SlaveCreateTransferTask(c *gin.Context) {
-	var service serializer.SlaveTransferReq
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := explorer.CreateTransferTask(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveGetCredential 从机获取主机的OneDrive存储策略凭证
+func SlaveGetCredential(c *gin.Context) {
+	service := ParametersFromContext[*node.OauthCredentialService](c, node.OauthCredentialParamCtx{})
+	cred, err := service.Get(c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, cred))
 }

-// SlaveNotificationPush 处理从机发送的消息推送
-func SlaveNotificationPush(c *gin.Context) {
-	var service node.SlaveNotificationService
-	if err := c.ShouldBindUri(&service); err == nil {
-		res := service.HandleSlaveNotificationPush(c)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveCreateTask creates a task and registers it in the registry
+func SlaveCreateTask(c *gin.Context) {
+	service := ParametersFromContext[*cluster.CreateSlaveTask](c, node.CreateSlaveTaskParamCtx{})
+	taskId, err := node.CreateTaskInSlave(service, c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, taskId))
 }

-// SlaveGetOauthCredential 从机获取主机的OneDrive存储策略凭证
-func SlaveGetOauthCredential(c *gin.Context) {
-	var service node.OauthCredentialService
-	if err := c.ShouldBindUri(&service); err == nil {
-		res := service.Get(c)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+// SlaveGetTask retrieves a task from the registry
+func SlaveGetTask(c *gin.Context) {
+	service := ParametersFromContext[*node.GetSlaveTaskService](c, node.GetSlaveTaskParamCtx{})
+	task, err := service.Get(c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, task))
 }

-// SlaveSelectTask 从机删除离线下载临时文件
-func SlaveDeleteTempFile(c *gin.Context) {
-	var service serializer.SlaveAria2Call
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := aria2.SlaveDeleteTemp(c, &service)
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
+func SlaveCleanupFolder(c *gin.Context) {
+	service := ParametersFromContext[*cluster.FolderCleanup](c, node.FolderCleanupParamCtx{})
+	if err := node.Cleanup(service, c); err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
 	}
+
+	c.JSON(200, serializer.Response{})
+}
+
+func StatelessPrepareUpload(c *gin.Context) {
+	service := ParametersFromContext[*fs.StatelessPrepareUploadService](c, node.StatelessPrepareUploadParamCtx{})
+	uploadSession, err := node.StatelessPrepareUpload(service, c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
+	}
+
+	c.JSON(200, serializer.NewResponseWithGobData(c, uploadSession))
+}
+
+func StatelessCompleteUpload(c *gin.Context) {
+	service := ParametersFromContext[*fs.StatelessCompleteUploadService](c, node.StatelessCompleteUploadParamCtx{})
+	_, err := node.StatelessCompleteUpload(service, c)
+	if err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
+	}
+
+	c.JSON(200, serializer.Response{})
+}
+
+func StatelessOnUploadFailed(c *gin.Context) {
+	service := ParametersFromContext[*fs.StatelessOnUploadFailedService](c, node.StatelessOnUploadFailedParamCtx{})
+	if err := node.StatelessOnUploadFailed(service, c); err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
+	}
+
+	c.JSON(200, serializer.Response{})
+}
+
+func StatelessCreateFile(c *gin.Context) {
+	service := ParametersFromContext[*fs.StatelessCreateFileService](c, node.StatelessCreateFileParamCtx{})
+	if err := node.StatelessCreateFile(service, c); err != nil {
+		c.JSON(200, serializer.Err(c, err))
+		c.Abort()
+		return
+	}
+
+	c.JSON(200, serializer.Response{})
 }
diff --git a/routers/controllers/tag.go b/routers/controllers/tag.go
deleted file mode 100644
index 341841bd..00000000
--- a/routers/controllers/tag.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package controllers
-
-import (
-	"github.com/cloudreve/Cloudreve/v3/service/explorer"
-	"github.com/gin-gonic/gin"
-)
-
-// CreateFilterTag 创建文件分类标签
-func CreateFilterTag(c *gin.Context) {
-	var service explorer.FilterTagCreateService
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := service.Create(c, CurrentUser(c))
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
-	}
-}
-
-// CreateLinkTag 创建目录快捷方式标签
-func CreateLinkTag(c *gin.Context) {
-	var service explorer.LinkTagCreateService
-	if err := c.ShouldBindJSON(&service); err == nil {
-		res := service.Create(c, CurrentUser(c))
-		c.JSON(200, res)
-	} else {
-		c.JSON(200, ErrorResponse(err))
-	}
-}
-
-// DeleteTag 删除标签
-func DeleteTag(c *gin.Context) {
-	var service explorer.TagService
-	if err := c.ShouldBindUri(&service); err
== nil { - res := service.Delete(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} diff --git a/routers/controllers/user.go b/routers/controllers/user.go index 5d6301ee..ca7d1809 100644 --- a/routers/controllers/user.go +++ b/routers/controllers/user.go @@ -1,216 +1,173 @@ package controllers import ( - "encoding/json" - "fmt" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/authn" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/service/user" - "github.com/duo-labs/webauthn/webauthn" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/cloudreve/Cloudreve/v4/service/share" + "github.com/cloudreve/Cloudreve/v4/service/user" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) // StartLoginAuthn 开始注册WebAuthn登录 func StartLoginAuthn(c *gin.Context) { - userName := c.Param("username") - expectedUser, err := model.GetActiveUserByEmail(userName) + res, err := user.PreparePasskeyLogin(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeUserNotFound, "", err)) + c.JSON(200, serializer.Err(c, err)) return } - instance, err := authn.NewAuthnInstance() - if err != nil { - c.JSON(200, serializer.Err(serializer.CodeInitializeAuthn, "Cannot initialize authn", err)) - return - } - - options, sessionData, err := instance.BeginLogin(expectedUser) - - if err != nil { - c.JSON(200, ErrorResponse(err)) - return - } - - val, err := json.Marshal(sessionData) - if err != nil { - c.JSON(200, ErrorResponse(err)) - return - } - - util.SetSession(c, map[string]interface{}{ - "registration-session": val, - }) - c.JSON(200, serializer.Response{Code: 0, Data: options}) + c.JSON(200, serializer.Response{Data: res}) } // FinishLoginAuthn 完成注册WebAuthn登录 func FinishLoginAuthn(c *gin.Context) { - userName := c.Param("username") - expectedUser, err := model.GetActiveUserByEmail(userName) + service := ParametersFromContext[*user.FinishPasskeyLoginService](c, user.FinishPasskeyLoginParameterCtx{}) + u, err := service.FinishPasskeyLogin(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeUserNotFound, "", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - sessionDataJSON := util.GetSession(c, "registration-session").([]byte) - - var sessionData webauthn.SessionData - err = json.Unmarshal(sessionDataJSON, &sessionData) + util.WithValue(c, inventory.UserCtx{}, u) +} - instance, err := authn.NewAuthnInstance() +// StartRegAuthn 开始注册WebAuthn信息 +func StartRegAuthn(c *gin.Context) { + res, err := user.PreparePasskeyRegister(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeInitializeAuthn, "Cannot initialize authn", err)) + c.JSON(200, serializer.Err(c, err)) return } - _, err = instance.FinishLogin(expectedUser, sessionData, c.Request) + c.JSON(200, serializer.Response{Data: res}) +} +// FinishRegAuthn 完成注册WebAuthn信息 +func FinishRegAuthn(c *gin.Context) { + service := ParametersFromContext[*user.FinishPasskeyRegisterService](c, user.FinishPasskeyRegisterParameterCtx{}) + res, err := service.FinishPasskeyRegister(c) if err != nil { - c.JSON(200, 
serializer.Err(serializer.CodeWebAuthnCredentialError, "Verification failed", err)) + c.JSON(200, serializer.Err(c, err)) return } - util.SetSession(c, map[string]interface{}{ - "user_id": expectedUser.ID, - }) - c.JSON(200, serializer.BuildUserResponse(expectedUser)) + c.JSON(200, serializer.Response{Data: res}) } -// StartRegAuthn 开始注册WebAuthn信息 -func StartRegAuthn(c *gin.Context) { - currUser := CurrentUser(c) - - instance, err := authn.NewAuthnInstance() +// UserDeletePasskey deletes user passkey +func UserDeletePasskey(c *gin.Context) { + service := ParametersFromContext[*user.DeletePasskeyService](c, user.DeletePasskeyParameterCtx{}) + err := service.DeletePasskey(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeInitializeAuthn, "Cannot initialize authn", err)) + c.JSON(200, serializer.Err(c, err)) return } - options, sessionData, err := instance.BeginRegistration(currUser) + c.JSON(200, serializer.Response{}) +} +// UserLoginValidation validates user login request +func UserLoginValidation(c *gin.Context) { + service := ParametersFromContext[*user.UserLoginService](c, user.LoginParameterCtx{}) + expectedUser, twoFaSession, err := service.Login(c) if err != nil { - c.JSON(200, ErrorResponse(err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - val, err := json.Marshal(sessionData) - if err != nil { - c.JSON(200, ErrorResponse(err)) + if twoFaSession == "" { + // No 2FA required, proceed + util.WithValue(c, inventory.UserCtx{}, expectedUser) + c.Next() return } - util.SetSession(c, map[string]interface{}{ - "registration-session": val, - }) - c.JSON(200, serializer.Response{Code: 0, Data: options}) + c.JSON(200, serializer.Response{Code: serializer.CodeNotFullySuccess, Data: twoFaSession}) + c.Abort() } -// FinishRegAuthn 完成注册WebAuthn信息 -func FinishRegAuthn(c *gin.Context) { - currUser := CurrentUser(c) - sessionDataJSON := util.GetSession(c, "registration-session").([]byte) - - var sessionData webauthn.SessionData - err := json.Unmarshal(sessionDataJSON, &sessionData) - - instance, err := authn.NewAuthnInstance() +// UserLogin2FAValidation validates user OTP code +func UserLogin2FAValidation(c *gin.Context) { + service := ParametersFromContext[*user.OtpValidationService](c, user.OtpValidationParameterCtx{}) + expectedUser, err := service.Verify2FA(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeInitializeAuthn, "Cannot initialize authn", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - credential, err := instance.FinishRegistration(currUser, sessionData, c.Request) + util.WithValue(c, inventory.UserCtx{}, expectedUser) + c.Next() +} +// UserIssueToken generates new token pair for user +func UserIssueToken(c *gin.Context) { + resp, err := user.IssueToken(c) if err != nil { - c.JSON(200, ErrorResponse(err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - err = currUser.RegisterAuthn(credential) + c.JSON(200, serializer.Response{ + Data: resp, + }) +} + +// UserRefreshToken refreshes token pair for user +func UserRefreshToken(c *gin.Context) { + service := ParametersFromContext[*user.RefreshTokenService](c, user.RefreshTokenParameterCtx{}) + resp, err := service.Refresh(c) if err != nil { - c.JSON(200, ErrorResponse(err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } c.JSON(200, serializer.Response{ - Code: 0, - Data: map[string]interface{}{ - "id": credential.ID, - "fingerprint": fmt.Sprintf("% X", credential.Authenticator.AAGUID), - }, + Data: resp, }) } -// UserLogin 用户登录 -func UserLogin(c 
*gin.Context) { - var service user.UserLoginService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Login(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - // UserRegister 用户注册 func UserRegister(c *gin.Context) { - var service user.UserRegisterService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Register(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - -// User2FALogin 用户二步验证登录 -func User2FALogin(c *gin.Context) { - var service user.Enable2FA - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Login(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + service := ParametersFromContext[*user.UserRegisterService](c, user.RegisterParameterCtx{}) + c.JSON(200, service.Register(c)) } // UserSendReset 发送密码重设邮件 func UserSendReset(c *gin.Context) { - var service user.UserResetEmailService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Reset(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*user.UserResetEmailService](c, user.UserResetEmailParameterCtx{}) + if err := service.Reset(c); err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + c.JSON(200, serializer.Response{}) } // UserReset 重设密码 func UserReset(c *gin.Context) { - var service user.UserResetService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Reset(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + service := ParametersFromContext[*user.UserResetService](c, user.UserResetParameterCtx{}) + res, err := service.Reset(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + c.JSON(200, serializer.Response{Data: res}) } // UserActivate 用户激活 func UserActivate(c *gin.Context) { - var service user.SettingService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Activate(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, user.ActivateUser(c)) } // UserSignOut 用户退出登录 @@ -221,178 +178,223 @@ func UserSignOut(c *gin.Context) { // UserMe 获取当前登录的用户 func UserMe(c *gin.Context) { - currUser := CurrentUser(c) - res := serializer.BuildUserResponse(*currUser) - c.JSON(200, res) + dep := dependency.FromContext(c) + c.JSON(200, serializer.Response{ + Data: user.BuildUser(inventory.UserFromContext(c), dep.HashIDEncoder()), + }) } -// UserStorage 获取用户的存储信息 -func UserStorage(c *gin.Context) { - currUser := CurrentUser(c) - res := serializer.BuildUserStorageResponse(*currUser) - c.JSON(200, res) -} +// UserGet 获取用户信息 +func UserGet(c *gin.Context) { + u, err := user.GetUser(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } -// UserTasks 获取任务队列 -func UserTasks(c *gin.Context) { - var service user.SettingListService - if err := c.ShouldBindQuery(&service); err == nil { - res := service.ListTasks(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + isAnonymous := inventory.IsAnonymousUser(inventory.UserFromContext(c)) + redactLevel := user.RedactLevelUser + if isAnonymous { + redactLevel = user.RedactLevelAnonymous } + c.JSON(200, serializer.Response{ + Data: user.BuildUserRedacted(u, redactLevel, dependency.FromContext(c).HashIDEncoder()), + }) } -// UserSetting 获取用户设定 -func UserSetting(c *gin.Context) { - var service user.SettingService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Settings(c, CurrentUser(c)) - c.JSON(200, 
res) - } else { - c.JSON(200, ErrorResponse(err)) +// UserStorage 获取用户的存储信息 +func UserStorage(c *gin.Context) { + res, err := user.GetUserCapacity(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } + + c.JSON(200, serializer.Response{ + Data: res, + }) } -// UseGravatar 设定头像使用全球通用 -func UseGravatar(c *gin.Context) { - u := CurrentUser(c) - if err := u.Update(map[string]interface{}{"avatar": "gravatar"}); err != nil { - c.JSON(200, serializer.Err(serializer.CodeDBError, "无法更新头像", err)) +// UserSetting 获取用户设定 +func UserSetting(c *gin.Context) { + res, err := user.GetUserSettings(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - c.JSON(200, serializer.Response{}) + + c.JSON(200, serializer.Response{ + Data: res, + }) } // UploadAvatar 从文件上传头像 func UploadAvatar(c *gin.Context) { - // 取得头像上传大小限制 - maxSize := model.GetIntSetting("avatar_size", 2097152) - if c.Request.ContentLength == -1 || c.Request.ContentLength > int64(maxSize) { - request.BlackHole(c.Request.Body) - c.JSON(200, serializer.Err(serializer.CodeFileTooLarge, "", nil)) + if err := user.UpdateUserAvatar(c); err != nil { + c.JSON(200, serializer.Err(c, err)) return } - // 取得上传的文件 - file, err := c.FormFile("avatar") + c.JSON(200, serializer.Response{}) +} + +// GetUserAvatar 获取用户头像 +func GetUserAvatar(c *gin.Context) { + service := ParametersFromContext[*user.GetAvatarService](c, user.GetAvatarServiceParamsCtx{}) + err := service.Get(c) if err != nil { - c.JSON(200, serializer.ParamErr("Failed to read avatar file data", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } +} - // 初始化头像 - r, err := file.Open() +// UpdateOption 更改用户设定 +func UpdateOption(c *gin.Context) { + service := ParametersFromContext[*user.PatchUserSetting](c, user.PatchUserSettingParamsCtx{}) + err := service.Patch(c) if err != nil { - c.JSON(200, serializer.ParamErr("Failed to read avatar file data", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - avatar, err := thumb.NewThumbFromFile(r, file.Filename) + + c.JSON(200, serializer.Response{}) + + //var service user.SettingUpdateService + //if err := c.ShouldBindUri(&service); err == nil { + // var ( + // subService user.OptionsChangeHandler + // subErr error + // ) + // + // switch service.Option { + // case "nick": + // subService = &user.ChangerNick{} + // case "vip": + // subService = &user.VIPUnsubscribe{} + // case "qq": + // subService = &user.QQBind{} + // case "policy": + // subService = &user.PolicyChange{} + // case "homepage": + // subService = &user.HomePage{} + // case "password": + // subService = &user.PasswordChange{} + // case "2fa": + // subService = &user.Enable2FA{} + // case "authn": + // subService = &user.DeleteWebAuthn{} + // case "theme": + // subService = &user.ThemeChose{} + // default: + // subService = &user.ChangerNick{} + // } + // + // subErr = c.ShouldBindJSON(subService) + // if subErr != nil { + // c.JSON(200, ErrorResponse(subErr)) + // return + // } + // + // res := subService.Update(c, CurrentUser(c)) + // c.JSON(200, res) + // + //} else { + // c.JSON(200, ErrorResponse(err)) + //} +} + +// UserInit2FA 初始化二步验证 +func UserInit2FA(c *gin.Context) { + secret, err := user.Init2FA(c) if err != nil { - c.JSON(200, serializer.ParamErr("Invalid image", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - // 创建头像 - u := CurrentUser(c) - err = avatar.CreateAvatar(u.ID) + c.JSON(200, serializer.Response{ + Data: secret, + }) +} + +// UserPerformCopySession copy to create new 
session or refresh current session +func UserPerformCopySession(c *gin.Context) { + //var service user.CopySessionService + //if err := c.ShouldBindUri(&service); err == nil { + // res := service.Copy(c) + // c.JSON(200, res) + //} else { + // c.JSON(200, ErrorResponse(err)) + //} +} + +// UserPrepareLogin validates precondition for login +func UserPrepareLogin(c *gin.Context) { + service := ParametersFromContext[*user.PrepareLoginService](c, user.PrepareLoginParameterCtx{}) + res, err := service.Prepare(c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeIOFailed, "Failed to create avatar file", err)) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - // 保存头像标记 - if err := u.Update(map[string]interface{}{ - "avatar": "file", - }); err != nil { - c.JSON(200, serializer.DBErr("Failed to update avatar attribute", err)) + c.JSON(200, serializer.Response{Data: res}) +} + +// UserSearch Search user by keyword +func UserSearch(c *gin.Context) { + service := ParametersFromContext[*user.SearchUserService](c, user.SearchUserParamCtx{}) + u, err := service.Search(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - c.JSON(200, serializer.Response{}) + hasher := dependency.FromContext(c).HashIDEncoder() + c.JSON(200, serializer.Response{ + Data: lo.Map(u, func(item *ent.User, index int) user.User { + return user.BuildUserRedacted(item, user.RedactLevelUser, hasher) + }), + }) } -// GetUserAvatar 获取用户头像 -func GetUserAvatar(c *gin.Context) { - var service user.AvatarService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Get(c) - if res.Code == -301 { - // 重定向到gravatar - c.Redirect(301, res.Data.(string)) - } - } else { - c.JSON(200, ErrorResponse(err)) +// GetGroupList list all groups for options +func GetGroupList(c *gin.Context) { + u, err := user.ListAllGroups(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// UpdateOption 更改用户设定 -func UpdateOption(c *gin.Context) { - var service user.SettingUpdateService - if err := c.ShouldBindUri(&service); err == nil { - var ( - subService user.OptionsChangeHandler - subErr error - ) - - switch service.Option { - case "nick": - subService = &user.ChangerNick{} - case "homepage": - subService = &user.HomePage{} - case "password": - subService = &user.PasswordChange{} - case "2fa": - subService = &user.Enable2FA{} - case "authn": - subService = &user.DeleteWebAuthn{} - case "theme": - subService = &user.ThemeChose{} - default: - subService = &user.ChangerNick{} - } - - subErr = c.ShouldBindJSON(subService) - if subErr != nil { - c.JSON(200, ErrorResponse(subErr)) - return - } - - res := subService.Update(c, CurrentUser(c)) - c.JSON(200, res) - - } else { - c.JSON(200, ErrorResponse(err)) - } + hasher := dependency.FromContext(c).HashIDEncoder() + c.JSON(200, serializer.Response{ + Data: lo.Map(u, func(item *ent.Group, index int) *user.Group { + g := user.BuildGroup(item, hasher) + return user.RedactedGroup(g) + }), + }) } -// UserInit2FA 初始化二步验证 -func UserInit2FA(c *gin.Context) { - var service user.SettingService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Init2FA(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// ListPublicShare lists all public shares for given user +func ListPublicShare(c *gin.Context) { + service := ParametersFromContext[*share.ListShareService](c, share.ListShareParamCtx{}) + resp, err := service.ListInUserProfile(c, hashid.FromContext(c)) + if err != nil { + c.JSON(200, 
serializer.Err(c, err)) + c.Abort() + return } -} -// UserPrepareCopySession generates URL for copy session -func UserPrepareCopySession(c *gin.Context) { - var service user.CopySessionService - res := service.Prepare(c, CurrentUser(c)) - c.JSON(200, res) - -} - -// UserPerformCopySession copy to create new session or refresh current session -func UserPerformCopySession(c *gin.Context) { - var service user.CopySessionService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Copy(c) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } diff --git a/routers/controllers/webdav.go b/routers/controllers/webdav.go index 0453ada4..3b3033ab 100644 --- a/routers/controllers/webdav.go +++ b/routers/controllers/webdav.go @@ -1,104 +1,101 @@ package controllers import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/pkg/webdav" - "github.com/cloudreve/Cloudreve/v3/service/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/setting" "github.com/gin-gonic/gin" - "net/http" - "sync" ) -var handler *webdav.Handler +// ListDavAccounts lists all WebDAV accounts. +func ListDavAccounts(c *gin.Context) { + service := ParametersFromContext[*setting.ListDavAccountsService](c, setting.ListDavAccountParamCtx{}) + resp, err := service.List(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } -func init() { - handler = &webdav.Handler{ - Prefix: "/dav", - LockSystem: make(map[uint]webdav.LockSystem), - Mutex: &sync.Mutex{}, + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) } } -// ServeWebDAV 处理WebDAV相关请求 -func ServeWebDAV(c *gin.Context) { - fs, err := filesystem.NewFileSystemFromContext(c) +// CreateDAVAccounts 创建WebDAV账户 +func CreateDAVAccounts(c *gin.Context) { + service := ParametersFromContext[*setting.CreateDavAccountService](c, setting.CreateDavAccountParamCtx{}) + resp, err := service.Create(c) if err != nil { - util.Log().Warning("Failed to initialize filesystem for WebDAV,%s", err) + c.JSON(200, serializer.Err(c, err)) + c.Abort() return } - if webdavCtx, ok := c.Get("webdav"); ok { - application := webdavCtx.(*model.Webdav) - - // 重定根目录 - if application.Root != "/" { - if exist, root := fs.IsPathExist(application.Root); exist { - root.Position = "" - root.Name = "/" - fs.Root = root - } - } - - // 检查是否只读 - if application.Readonly { - switch c.Request.Method { - case "DELETE", "PUT", "MKCOL", "COPY", "MOVE": - c.Status(http.StatusForbidden) - return - } - } - - // 更新Context - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), fsctx.WebDAVCtx, application)) - } - - handler.ServeHTTP(c.Writer, c.Request, fs) + c.JSON(200, serializer.Response{ + Data: resp, + }) } -// GetWebDAVAccounts 获取webdav账号列表 -func GetWebDAVAccounts(c *gin.Context) { - var service setting.WebDAVListService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Accounts(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// UpdateDAVAccounts updates WebDAV accounts. 
+func UpdateDAVAccounts(c *gin.Context) { + service := ParametersFromContext[*setting.CreateDavAccountService](c, setting.CreateDavAccountParamCtx{}) + resp, err := service.Update(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// DeleteWebDAVAccounts 删除WebDAV账户 -func DeleteWebDAVAccounts(c *gin.Context) { - var service setting.WebDAVAccountService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Delete(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{ + Data: resp, + }) } -// UpdateWebDAVAccounts 更改WebDAV账户只读性和是否使用代理服务 -func UpdateWebDAVAccounts(c *gin.Context) { - var service setting.WebDAVAccountUpdateService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Update(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) +// DeleteDAVAccounts deletes WebDAV accounts. +func DeleteDAVAccounts(c *gin.Context) { + err := setting.DeleteDavAccount(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return } -} -// CreateWebDAVAccounts 创建WebDAV账户 -func CreateWebDAVAccounts(c *gin.Context) { - var service setting.WebDAVAccountCreateService - if err := c.ShouldBindJSON(&service); err == nil { - res := service.Create(c, CurrentUser(c)) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } + c.JSON(200, serializer.Response{}) } + +// +//// DeleteWebDAVAccounts 删除WebDAV账户 +//func DeleteWebDAVAccounts(c *gin.Context) { +// var service setting.WebDAVAccountService +// if err := c.ShouldBindUri(&service); err == nil { +// res := service.Delete(c, CurrentUser(c)) +// c.JSON(200, res) +// } else { +// c.JSON(200, ErrorResponse(err)) +// } +//} +// +//// DeleteWebDAVMounts 删除WebDAV挂载 +//func DeleteWebDAVMounts(c *gin.Context) { +// var service setting.WebDAVListService +// if err := c.ShouldBindUri(&service); err == nil { +// res := service.Unmount(c, CurrentUser(c)) +// c.JSON(200, res) +// } else { +// c.JSON(200, ErrorResponse(err)) +// } +//} +// +// +//// CreateWebDAVMounts 创建WebDAV目录挂载 +//func CreateWebDAVMounts(c *gin.Context) { +// var service setting.WebDAVMountCreateService +// if err := c.ShouldBindJSON(&service); err == nil { +// res := service.Create(c, CurrentUser(c)) +// c.JSON(200, res) +// } else { +// c.JSON(200, ErrorResponse(err)) +// } +//} diff --git a/routers/controllers/wopi.go b/routers/controllers/wopi.go index 23eea155..2d65e50c 100644 --- a/routers/controllers/wopi.go +++ b/routers/controllers/wopi.go @@ -1,10 +1,8 @@ package controllers import ( - "context" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/cloudreve/Cloudreve/v3/service/explorer" + "github.com/cloudreve/Cloudreve/v4/pkg/wopi" + "github.com/cloudreve/Cloudreve/v4/service/explorer" "github.com/gin-gonic/gin" "net/http" ) @@ -35,43 +33,46 @@ func GetFile(c *gin.Context) { // PutFile Puts file content func PutFile(c *gin.Context) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - service := &explorer.FileIDService{} - res := service.PutContent(ctx, c) - switch res.Code { - case serializer.CodeFileTooLarge: - c.Status(http.StatusRequestEntityTooLarge) - c.Header(wopi.ServerErrorHeader, res.Error) - case serializer.CodeNotFound: - c.Status(http.StatusNotFound) - c.Header(wopi.ServerErrorHeader, res.Error) - case 0: - c.Status(http.StatusOK) - default: + service := &explorer.WopiService{} + err := 
service.PutContent(c) + if err != nil { c.Status(http.StatusInternalServerError) - c.Header(wopi.ServerErrorHeader, res.Error) + c.Header(wopi.ServerErrorHeader, err.Error()) } } // ModifyFile Modify file properties func ModifyFile(c *gin.Context) { action := c.GetHeader(wopi.OverwriteHeader) + var ( + service explorer.WopiService + err error + ) + switch action { - case wopi.MethodLock, wopi.MethodRefreshLock, wopi.MethodUnlock: - c.Status(http.StatusOK) - return - case wopi.MethodRename: - var service explorer.WopiService - err := service.Rename(c) - if err != nil { - c.Status(http.StatusInternalServerError) - c.Header(wopi.ServerErrorHeader, err.Error()) + case wopi.MethodLock: + err = service.Lock(c) + if err == nil { + return + } + case wopi.MethodRefreshLock: + err = service.RefreshLock(c) + if err == nil { + return + } + case wopi.MethodUnlock: + err = service.Unlock(c) + if err == nil { return } default: c.Status(http.StatusNotImplemented) return } + + if err != nil { + c.Status(http.StatusInternalServerError) + c.Header(wopi.ServerErrorHeader, err.Error()) + return + } } diff --git a/routers/controllers/workflow.go b/routers/controllers/workflow.go new file mode 100644 index 00000000..f2a8a3a2 --- /dev/null +++ b/routers/controllers/workflow.go @@ -0,0 +1,68 @@ +package controllers + +import ( + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/explorer" + "github.com/gin-gonic/gin" +) + +func ListTasks(c *gin.Context) { + service := ParametersFromContext[*explorer.ListTaskService](c, explorer.ListTaskParamCtx{}) + resp, err := service.ListTasks(c) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) + } +} + +func GetTaskPhaseProgress(c *gin.Context) { + taskId := hashid.FromContext(c) + resp, err := explorer.TaskPhaseProgress(c, taskId) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + if resp != nil { + c.JSON(200, serializer.Response{ + Data: resp, + }) + } else { + c.JSON(200, serializer.Response{Data: queue.Progresses{}}) + } +} + +func SetDownloadTaskTarget(c *gin.Context) { + taskId := hashid.FromContext(c) + service := ParametersFromContext[*explorer.SetDownloadFilesService](c, explorer.SetDownloadFilesParamCtx{}) + err := service.SetDownloadFiles(c, taskId) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + c.JSON(200, serializer.Response{}) +} + +func CancelDownloadTask(c *gin.Context) { + taskId := hashid.FromContext(c) + err := explorer.CancelDownloadTask(c, taskId) + if err != nil { + c.JSON(200, serializer.Err(c, err)) + c.Abort() + return + } + + c.JSON(200, serializer.Response{}) +} diff --git a/routers/main_test.go b/routers/main_test.go deleted file mode 100644 index 83664bdb..00000000 --- a/routers/main_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package routers - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" -) - -var mock sqlmock.Sqlmock -var memDB *gorm.DB -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - // 设置gin为测试模式 - gin.SetMode(gin.TestMode) - - // 初始化sqlmock - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when 
opening a stub database connection") - } - - // 初始话内存数据库 - model.Init() - memDB = model.DB - - mockDB, _ = gorm.Open("mysql", db) - model.DB = memDB - defer db.Close() - - m.Run() -} - -func switchToMemDB() { - model.DB = memDB -} - -func switchToMockDB() { - model.DB = mockDB -} diff --git a/routers/router.go b/routers/router.go index aa8e9033..0389fb4c 100644 --- a/routers/router.go +++ b/routers/router.go @@ -1,147 +1,238 @@ package routers import ( - "github.com/cloudreve/Cloudreve/v3/middleware" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - wopi2 "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/cloudreve/Cloudreve/v3/routers/controllers" + "net/http" + + "github.com/abslant/gzip" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/middleware" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/conf" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/webdav" + "github.com/cloudreve/Cloudreve/v4/routers/controllers" + adminsvc "github.com/cloudreve/Cloudreve/v4/service/admin" + "github.com/cloudreve/Cloudreve/v4/service/basic" + "github.com/cloudreve/Cloudreve/v4/service/explorer" + "github.com/cloudreve/Cloudreve/v4/service/node" + "github.com/cloudreve/Cloudreve/v4/service/setting" + sharesvc "github.com/cloudreve/Cloudreve/v4/service/share" + usersvc "github.com/cloudreve/Cloudreve/v4/service/user" "github.com/gin-contrib/cors" - "github.com/gin-contrib/gzip" "github.com/gin-gonic/gin" ) // InitRouter 初始化路由 -func InitRouter() *gin.Engine { - if conf.SystemConfig.Mode == "master" { - util.Log().Info("Current running mode: Master.") - return InitMasterRouter() +func InitRouter(dep dependency.Dep) *gin.Engine { + l := dep.Logger() + if dep.ConfigProvider().System().Mode == conf.MasterMode { + l.Info("Current running mode: Master.") + return initMasterRouter(dep) } - util.Log().Info("Current running mode: Slave.") - return InitSlaveRouter() + l.Info("Current running mode: Slave.") + return initSlaveRouter(dep) + +} + +func newGinEngine(dep dependency.Dep) *gin.Engine { + r := gin.New() + r.ContextWithFallback = true + r.Use(gin.Recovery()) + r.Use(middleware.InitializeHandling(dep)) + if dep.ConfigProvider().System().Mode == conf.SlaveMode { + r.Use(middleware.InitializeHandlingSlave()) + } + r.Use(middleware.Logging()) + return r +} + +func initSlaveFileRouter(v4 *gin.RouterGroup) { + // Upload related, no signature required under this router group + upload := v4.Group("upload") + { + // 上传分片 + upload.POST(":sessionId", + controllers.FromUri[explorer.UploadService](explorer.UploadParameterCtx{}), + controllers.SlaveUpload, + ) + // 创建上传会话上传 + upload.PUT("", + controllers.FromJSON[explorer.SlaveCreateUploadSessionService](explorer.SlaveCreateUploadSessionParamCtx{}), + controllers.SlaveGetUploadSession, + ) + // 删除上传会话 + upload.DELETE(":sessionId", + 
controllers.FromUri[explorer.SlaveDeleteUploadSessionService](explorer.SlaveDeleteUploadSessionParamCtx{}), + controllers.SlaveDeleteUploadSession) + } + file := v4.Group("file") + { + // Get entity content for preview/download + file.GET("content/:nodeId/:src/:speed/:name", + middleware.Sandbox(), + controllers.FromUri[explorer.EntityDownloadService](explorer.EntityDownloadParameterCtx{}), + controllers.SlaveServeEntity, + ) + file.HEAD("content/:nodeId/:src/:speed/:name", + controllers.FromUri[explorer.EntityDownloadService](explorer.EntityDownloadParameterCtx{}), + controllers.SlaveServeEntity, + ) + // Get media metadata + file.GET("meta/:src/:ext", + controllers.FromUri[explorer.SlaveMetaService](explorer.SlaveMetaParamCtx{}), + controllers.SlaveMeta, + ) + // Get thumbnail + file.GET("thumb/:src/:ext", + controllers.FromUri[explorer.SlaveThumbService](explorer.SlaveThumbParamCtx{}), + controllers.SlaveThumb, + ) + // 删除文件 + file.DELETE("", + controllers.FromJSON[explorer.SlaveDeleteFileService](explorer.SlaveDeleteFileParamCtx{}), + controllers.SlaveDelete) + } } -// InitSlaveRouter 初始化从机模式路由 -func InitSlaveRouter() *gin.Engine { - r := gin.Default() +// initSlaveRouter 初始化从机模式路由 +func initSlaveRouter(dep dependency.Dep) *gin.Engine { + r := newGinEngine(dep) // 跨域相关 - InitCORS(r) - v3 := r.Group("/api/v3/slave") + initCORS(dep.Logger(), dep.ConfigProvider(), r) + v4 := r.Group(constants.APIPrefix + "/slave") // 鉴权中间件 - v3.Use(middleware.SignRequired(auth.General)) - // 主机信息解析 - v3.Use(middleware.MasterMetadata()) + v4.Use(middleware.SignRequired(dep.GeneralAuth())) // 禁止缓存 - v3.Use(middleware.CacheControl()) + v4.Use(middleware.CacheControl()) /* 路由 */ { // Ping - v3.POST("ping", controllers.SlavePing) - // 测试 Aria2 RPC 连接 - v3.POST("ping/aria2", controllers.AdminTestAria2) - // 接收主机心跳包 - v3.POST("heartbeat", controllers.SlaveHeartbeat) - // 上传 - upload := v3.Group("upload") - { - // 上传分片 - upload.POST(":sessionId", controllers.SlaveUpload) - // 创建上传会话上传 - upload.PUT("", controllers.SlaveGetUploadSession) - // 删除上传会话 - upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession) - } - // 下载 - v3.GET("download/:speed/:path/:name", controllers.SlaveDownload) - // 预览 / 外链 - v3.GET("source/:speed/:path/:name", controllers.SlavePreview) - // 缩略图 - v3.GET("thumb/:path/:ext", controllers.SlaveThumb) - // 删除文件 - v3.POST("delete", controllers.SlaveDelete) + v4.POST("ping", + controllers.FromJSON[adminsvc.SlavePingService](adminsvc.SlavePingParameterCtx{}), + controllers.SlavePing, + ) + // // 测试 Aria2 RPC 连接 + // v4.POST("ping/aria2", controllers.AdminTestAria2) + initSlaveFileRouter(v4) + // 列出文件 - v3.POST("list", controllers.SlaveList) + v4.POST("list", controllers.SlaveList) // 离线下载 - aria2 := v3.Group("aria2") - aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController)) + download := v4.Group("download") { // 创建离线下载任务 - aria2.POST("task", controllers.SlaveAria2Create) + download.POST("task", + controllers.FromJSON[slave.CreateSlaveDownload](node.CreateSlaveDownloadTaskParamCtx{}), + middleware.PrepareSlaveDownloader(dep, node.CreateSlaveDownloadTaskParamCtx{}), + controllers.SlaveDownloadTaskCreate) // 获取任务状态 - aria2.POST("status", controllers.SlaveAria2Status) + download.POST("status", + controllers.FromJSON[slave.GetSlaveDownload](node.GetSlaveDownloadTaskParamCtx{}), + middleware.PrepareSlaveDownloader(dep, node.GetSlaveDownloadTaskParamCtx{}), + controllers.SlaveDownloadTaskStatus) // 取消离线下载任务 - aria2.POST("cancel", controllers.SlaveCancelAria2Task) + 
download.POST("cancel", + controllers.FromJSON[slave.CancelSlaveDownload](node.CancelSlaveDownloadTaskParamCtx{}), + middleware.PrepareSlaveDownloader(dep, node.CancelSlaveDownloadTaskParamCtx{}), + controllers.SlaveCancelDownloadTask) // 选取任务文件 - aria2.POST("select", controllers.SlaveSelectTask) - // 删除任务临时文件 - aria2.POST("delete", controllers.SlaveDeleteTempFile) + download.POST("select", + controllers.FromJSON[slave.SetSlaveFilesToDownload](node.SelectSlaveDownloadFilesParamCtx{}), + middleware.PrepareSlaveDownloader(dep, node.SelectSlaveDownloadFilesParamCtx{}), + controllers.SlaveSelectFilesToDownload) + // 测试下载器连接 + download.POST("test", + controllers.FromJSON[slave.TestSlaveDownload](node.TestSlaveDownloadParamCtx{}), + middleware.PrepareSlaveDownloader(dep, node.TestSlaveDownloadParamCtx{}), + controllers.SlaveTestDownloader, + ) } // 异步任务 - task := v3.Group("task") + task := v4.Group("task") { - task.PUT("transfer", controllers.SlaveCreateTransferTask) + task.PUT("", + controllers.FromJSON[cluster.CreateSlaveTask](node.CreateSlaveTaskParamCtx{}), + controllers.SlaveCreateTask) + task.GET(":id", + controllers.FromUri[node.GetSlaveTaskService](node.GetSlaveTaskParamCtx{}), + controllers.SlaveGetTask) + task.POST("cleanup", + controllers.FromJSON[cluster.FolderCleanup](node.FolderCleanupParamCtx{}), + controllers.SlaveCleanupFolder) } } return r } -// InitCORS 初始化跨域配置 -func InitCORS(router *gin.Engine) { - if conf.CORSConfig.AllowOrigins[0] != "UNSET" { +// initCORS 初始化跨域配置 +func initCORS(l logging.Logger, config conf.ConfigProvider, router *gin.Engine) { + c := config.Cors() + if c.AllowOrigins[0] != "UNSET" { router.Use(cors.New(cors.Config{ - AllowOrigins: conf.CORSConfig.AllowOrigins, - AllowMethods: conf.CORSConfig.AllowMethods, - AllowHeaders: conf.CORSConfig.AllowHeaders, - AllowCredentials: conf.CORSConfig.AllowCredentials, - ExposeHeaders: conf.CORSConfig.ExposeHeaders, + AllowOrigins: c.AllowOrigins, + AllowMethods: c.AllowMethods, + AllowHeaders: c.AllowHeaders, + AllowCredentials: c.AllowCredentials, + ExposeHeaders: c.ExposeHeaders, })) return } // slave模式下未启动跨域的警告 - if conf.SystemConfig.Mode == "slave" { - util.Log().Warning("You are running Cloudreve as slave node, if you are using slave storage policy, please enable CORS feature in config file, otherwise file cannot be uploaded from Master site.") + if config.System().Mode == conf.SlaveMode { + l.Warning("You are running Cloudreve as slave node, if you are using slave storage policy, please enable CORS feature in config file, otherwise file cannot be uploaded from Master basic.") } } -// InitMasterRouter 初始化主机模式路由 -func InitMasterRouter() *gin.Engine { - r := gin.Default() +// initMasterRouter 初始化主机模式路由 +func initMasterRouter(dep dependency.Dep) *gin.Engine { + r := newGinEngine(dep) + // 跨域相关 + initCORS(dep.Logger(), dep.ConfigProvider(), r) // Done /* 静态资源 */ - r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{"/api/"}))) - r.Use(middleware.FrontendFileHandler()) - r.GET("manifest.json", controllers.Manifest) + r.Use(gzip.GzipHandler()) // Done + r.Use(middleware.FrontendFileHandler(dep)) // Done + r.GET("manifest.json", controllers.Manifest) // Done - v3 := r.Group("/api/v3") + noAuth := r.Group(constants.APIPrefix) + wopi := noAuth.Group("file/wopi", middleware.HashID(hashid.FileID), middleware.ViewerSessionValidation()) + { + // 获取文件信息 + wopi.GET(":id", controllers.CheckFileInfo) + // 获取文件内容 + wopi.GET(":id/contents", controllers.GetFile) + // 更新文件内容 + wopi.POST(":id/contents", 
controllers.PutFile) + // 通用文件操作 + wopi.POST(":id", controllers.ModifyFile) + } + + v4 := r.Group(constants.APIPrefix) /* 中间件 */ - v3.Use(middleware.Session(conf.SystemConfig.SessionSecret)) - // 跨域相关 - InitCORS(r) - // 测试模式加入Mock助手中间件 - if gin.Mode() == gin.TestMode { - v3.Use(middleware.MockHelper()) - } + v4.Use(middleware.Session(dep)) // Done + // 用户会话 - v3.Use(middleware.CurrentUser()) + v4.Use(middleware.CurrentUser()) // 禁止缓存 - v3.Use(middleware.CacheControl()) + v4.Use(middleware.CacheControl()) // Done /* 路由 @@ -152,88 +243,149 @@ func InitMasterRouter() *gin.Engine { { source.GET(":id/:name", middleware.HashID(hashid.SourceLinkID), - middleware.ValidateSourceLink(), controllers.AnonymousPermLink) } + shareShort := r.Group("s") + { + shareShort.GET(":id", + controllers.FromUri[sharesvc.ShortLinkRedirectService](sharesvc.ShortLinkRedirectParamCtx{}), + controllers.ShareRedirect, + ) + shareShort.GET(":id/:password", + controllers.FromUri[sharesvc.ShortLinkRedirectService](sharesvc.ShortLinkRedirectParamCtx{}), + controllers.ShareRedirect, + ) + } + // 全局设置相关 - site := v3.Group("site") + site := v4.Group("site") { // 测试用路由 site.GET("ping", controllers.Ping) // 验证码 site.GET("captcha", controllers.Captcha) // 站点全局配置 - site.GET("config", middleware.CSRFInit(), controllers.SiteConfig) + site.GET("config/:section", + controllers.FromUri[basic.GetSettingService](basic.GetSettingParamCtx{}), + controllers.SiteConfig, + ) + } + + // User authentication + session := v4.Group("session") + { + token := session.Group("token") + // Token based authentication + { + // 用户登录 + token.POST("", + middleware.CaptchaRequired(func(c *gin.Context) bool { + return dep.SettingProvider().LoginCaptchaEnabled(c) + }), + controllers.FromJSON[usersvc.UserLoginService](usersvc.LoginParameterCtx{}), + controllers.UserLoginValidation, + controllers.UserIssueToken, + ) + // 2-factor authentication + token.POST("2fa", + controllers.FromJSON[usersvc.OtpValidationService](usersvc.OtpValidationParameterCtx{}), + controllers.UserLogin2FAValidation, + controllers.UserIssueToken, + ) + token.POST("refresh", + controllers.FromJSON[usersvc.RefreshTokenService](usersvc.RefreshTokenParameterCtx{}), + controllers.UserRefreshToken, + ) + } + + // Prepare login + session.GET("prepare", + controllers.FromQuery[usersvc.PrepareLoginService](usersvc.PrepareLoginParameterCtx{}), + controllers.UserPrepareLogin, + ) + + authn := session.Group("authn") + { + // WebAuthn login prepare + authn.PUT("", + middleware.IsFunctionEnabled(func(c *gin.Context) bool { + return dep.SettingProvider().AuthnEnabled(c) + }), + controllers.StartLoginAuthn, + ) + // WebAuthn finish login + authn.POST("", + middleware.IsFunctionEnabled(func(c *gin.Context) bool { + return dep.SettingProvider().AuthnEnabled(c) + }), + controllers.FromJSON[usersvc.FinishPasskeyLoginService](usersvc.FinishPasskeyLoginParameterCtx{}), + controllers.FinishLoginAuthn, + controllers.UserIssueToken, + ) + } } // 用户相关路由 - user := v3.Group("user") + user := v4.Group("user") { - // 用户登录 - user.POST("session", middleware.CaptchaRequired("login_captcha"), controllers.UserLogin) - // 用户注册 + // 用户注册 Done user.POST("", - middleware.IsFunctionEnabled("register_enabled"), - middleware.CaptchaRequired("reg_captcha"), + middleware.IsFunctionEnabled(func(c *gin.Context) bool { + return dep.SettingProvider().RegisterEnabled(c) + }), + middleware.CaptchaRequired(func(c *gin.Context) bool { + return dep.SettingProvider().RegCaptchaEnabled(c) + }), + 
controllers.FromJSON[usersvc.UserRegisterService](usersvc.RegisterParameterCtx{}), controllers.UserRegister, ) - // 用二步验证户登录 - user.POST("2fa", controllers.User2FALogin) - // 发送密码重设邮件 - user.POST("reset", middleware.CaptchaRequired("forget_captcha"), controllers.UserSendReset) // 通过邮件里的链接重设密码 - user.PATCH("reset", controllers.UserReset) - // 邮件激活 - user.GET("activate/:id", - middleware.SignRequired(auth.General), + user.PATCH("reset/:id", middleware.HashID(hashid.UserID), - controllers.UserActivate, - ) - // WebAuthn登陆初始化 - user.GET("authn/:username", - middleware.IsFunctionEnabled("authn_enabled"), - controllers.StartLoginAuthn, + controllers.FromJSON[usersvc.UserResetService](usersvc.UserResetParameterCtx{}), + controllers.UserReset, ) - // WebAuthn登陆 - user.POST("authn/finish/:username", - middleware.IsFunctionEnabled("authn_enabled"), - controllers.FinishLoginAuthn, + // 发送密码重设邮件 + user.POST("reset", + middleware.CaptchaRequired(func(c *gin.Context) bool { + return dep.SettingProvider().ForgotPasswordCaptchaEnabled(c) + }), + controllers.FromJSON[usersvc.UserResetEmailService](usersvc.UserResetEmailParameterCtx{}), + controllers.UserSendReset, ) - // 获取用户主页展示用分享 - user.GET("profile/:id", + // 邮件激活 Done + user.GET("activate/:id", + middleware.SignRequired(dep.GeneralAuth()), middleware.HashID(hashid.UserID), - controllers.GetUserShare, + controllers.UserActivate, ) // 获取用户头像 - user.GET("avatar/:id/:size", + user.GET("avatar/:id", middleware.HashID(hashid.UserID), - middleware.StaticResourceCache(), + controllers.FromQuery[usersvc.GetAvatarService](usersvc.GetAvatarServiceParamsCtx{}), controllers.GetUserAvatar, ) + // User info + user.GET("info/:id", middleware.HashID(hashid.UserID), controllers.UserGet) + // List user shares + user.GET("shares/:id", + middleware.HashID(hashid.UserID), + controllers.FromQuery[sharesvc.ListShareService](sharesvc.ListShareParamCtx{}), + controllers.ListPublicShare, + ) } // 需要携带签名验证的 - sign := v3.Group("") - sign.Use(middleware.SignRequired(auth.General)) + sign := v4.Group("") + sign.Use(middleware.SignRequired(dep.GeneralAuth())) { file := sign.Group("file") { - // 文件外链(直接输出文件数据) - file.GET("get/:id/:name", - middleware.Sandbox(), - middleware.StaticResourceCache(), - controllers.AnonymousGetContent, - ) - // 文件外链(301跳转) - file.GET("source/:id/:name", controllers.AnonymousPermLinkDeprecated) - // 下载文件 - file.GET("download/:id", - middleware.StaticResourceCache(), - controllers.Download, + file.GET("archive/:sessionID/archive.zip", + controllers.FromUri[explorer.ArchiveService](explorer.ArchiveParamCtx{}), + controllers.DownloadArchive, ) - // 打包并下载文件 - file.GET("archive/:sessionID/archive.zip", controllers.DownloadArchive) } // Copy user session @@ -244,70 +396,70 @@ func InitMasterRouter() *gin.Engine { ) } - // 从机的 RPC 通信 - slave := v3.Group("slave") - slave.Use(middleware.SlaveRPCSignRequired(cluster.Default)) + // Receive calls from slave node + slave := v4.Group("slave") + slave.Use( + middleware.SlaveRPCSignRequired(), + ) { - // 事件通知 - slave.PUT("notification/:subject", controllers.SlaveNotificationPush) - // 上传 - upload := slave.Group("upload") + initSlaveFileRouter(slave) + // Get credential + slave.GET("credential/:id", + controllers.FromUri[node.OauthCredentialService](node.OauthCredentialParamCtx{}), + controllers.SlaveGetCredential) + statelessUpload := slave.Group("statelessUpload") { - // 上传分片 - upload.POST(":sessionId", controllers.SlaveUpload) - // 创建上传会话上传 - upload.PUT("", controllers.SlaveGetUploadSession) - // 删除上传会话 - 
upload.DELETE(":sessionId", controllers.SlaveDeleteUploadSession) + // Prepare upload + statelessUpload.PUT("prepare", + controllers.FromJSON[fs.StatelessPrepareUploadService](node.StatelessPrepareUploadParamCtx{}), + controllers.StatelessPrepareUpload) + // Complete upload + statelessUpload.POST("complete", + controllers.FromJSON[fs.StatelessCompleteUploadService](node.StatelessCompleteUploadParamCtx{}), + controllers.StatelessCompleteUpload) + // On upload failed + statelessUpload.POST("failed", + controllers.FromJSON[fs.StatelessOnUploadFailedService](node.StatelessOnUploadFailedParamCtx{}), + controllers.StatelessOnUploadFailed) + // Create file + statelessUpload.POST("create", + controllers.FromJSON[fs.StatelessCreateFileService](node.StatelessCreateFileParamCtx{}), + controllers.StatelessCreateFile) } - // Oauth 存储策略凭证 - slave.GET("credential/:id", controllers.SlaveGetOauthCredential) } // 回调接口 - callback := v3.Group("callback") + callback := v4.Group("callback") { // 远程策略上传回调 callback.POST( "remote/:sessionID/:key", - middleware.UseUploadSession("remote"), + middleware.UseUploadSession(types.PolicyTypeRemote), middleware.RemoteCallbackAuth(), - controllers.RemoteCallback, + controllers.ProcessCallback(http.StatusOK, false), ) - // 七牛策略上传回调 + // OSS callback callback.POST( - "qiniu/:sessionID", - middleware.UseUploadSession("qiniu"), - middleware.QiniuCallbackAuth(), - controllers.QiniuCallback, - ) - // 阿里云OSS策略上传回调 - callback.POST( - "oss/:sessionID", - middleware.UseUploadSession("oss"), + "oss/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeOss), middleware.OSSCallbackAuth(), - controllers.OSSCallback, + controllers.OSSCallbackValidate, + controllers.ProcessCallback(http.StatusBadRequest, false), ) // 又拍云策略上传回调 callback.POST( - "upyun/:sessionID", - middleware.UseUploadSession("upyun"), - middleware.UpyunCallbackAuth(), - controllers.UpyunCallback, + "upyun/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeUpyun), + controllers.UpyunCallbackAuth, + controllers.ProcessCallback(http.StatusBadRequest, false), ) onedrive := callback.Group("onedrive") { // 文件上传完成 onedrive.POST( - "finish/:sessionID", - middleware.UseUploadSession("onedrive"), - middleware.OneDriveCallbackAuth(), - controllers.OneDriveCallback, - ) - // OAuth 完成 - onedrive.GET( - "auth", - controllers.OneDriveOAuth, + ":sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeOd), + controllers.ProcessCallback(http.StatusOK, false), ) } // Google Drive related @@ -321,244 +473,568 @@ func InitMasterRouter() *gin.Engine { } // 腾讯云COS策略上传回调 callback.GET( - "cos/:sessionID", - middleware.UseUploadSession("cos"), - controllers.COSCallback, + "cos/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeCos), + controllers.ProcessCallback(http.StatusBadRequest, false), ) // AWS S3策略上传回调 callback.GET( - "s3/:sessionID", - middleware.UseUploadSession("s3"), - controllers.S3Callback, + "s3/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeS3), + controllers.ProcessCallback(http.StatusBadRequest, false), + ) + // Huawei OBS upload callback + callback.POST( + "obs/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeObs), + controllers.ProcessCallback(http.StatusBadRequest, false), + ) + // Qiniu callback + callback.POST( + "qiniu/:sessionID/:key", + middleware.UseUploadSession(types.PolicyTypeQiniu), + controllers.QiniuCallbackValidate, + controllers.ProcessCallback(http.StatusBadRequest, true), ) } - // 分享相关 - share := v3.Group("share", 
middleware.ShareAvailable()) + // Workflows + wf := v4.Group("workflow") + wf.Use(middleware.LoginRequired()) + { + // List + wf.GET("", + controllers.FromQuery[explorer.ListTaskService](explorer.ListTaskParamCtx{}), + controllers.ListTasks, + ) + // GetTaskProgress + wf.GET("progress/:id", + middleware.HashID(hashid.TaskID), + controllers.GetTaskPhaseProgress, + ) + // Create task to create an archive file + wf.POST("archive", + controllers.FromJSON[explorer.ArchiveWorkflowService](explorer.CreateArchiveParamCtx{}), + controllers.CreateArchive, + ) + // Create task to extract an archive file + wf.POST("extract", + controllers.FromJSON[explorer.ArchiveWorkflowService](explorer.CreateArchiveParamCtx{}), + controllers.ExtractArchive, + ) + + remoteDownload := wf.Group("download") + { + // Create task to download a file + remoteDownload.POST("", + controllers.FromJSON[explorer.DownloadWorkflowService](explorer.CreateDownloadParamCtx{}), + controllers.CreateRemoteDownload, + ) + // Set download target + remoteDownload.PATCH(":id", + middleware.HashID(hashid.TaskID), + controllers.FromJSON[explorer.SetDownloadFilesService](explorer.SetDownloadFilesParamCtx{}), + controllers.SetDownloadTaskTarget, + ) + remoteDownload.DELETE(":id", + middleware.HashID(hashid.TaskID), + controllers.CancelDownloadTask, + ) + } + } + + // 文件 + file := v4.Group("file") { - // 获取分享 - share.GET("info/:id", controllers.GetShare) - // 创建文件下载会话 - share.PUT("download/:id", - middleware.CheckShareUnlocked(), - middleware.BeforeShareDownload(), - controllers.GetShareDownload, - ) - // 预览分享文件 - share.GET("preview/:id", - middleware.CSRFCheck(), - middleware.CheckShareUnlocked(), - middleware.ShareCanPreview(), - middleware.BeforeShareDownload(), - controllers.PreviewShare, - ) - // 取得Office文档预览地址 - share.GET("doc/:id", - middleware.CheckShareUnlocked(), - middleware.ShareCanPreview(), - middleware.BeforeShareDownload(), - controllers.GetShareDocPreview, - ) - // 获取文本文件内容 - share.GET("content/:id", - middleware.CheckShareUnlocked(), - middleware.BeforeShareDownload(), - controllers.PreviewShareText, - ) - // 分享目录列文件 - share.GET("list/:id/*path", - middleware.CheckShareUnlocked(), - controllers.ListSharedFolder, - ) - // 分享目录搜索 - share.GET("search/:id/:type/:keywords", - middleware.CheckShareUnlocked(), - controllers.SearchSharedFolder, - ) - // 归档打包下载 - share.POST("archive/:id", - middleware.CheckShareUnlocked(), - middleware.BeforeShareDownload(), - controllers.ArchiveShare, - ) - // 获取README文本文件内容 - share.GET("readme/:id", - middleware.CheckShareUnlocked(), - controllers.PreviewShareReadme, + // List files + file.GET("", + controllers.FromQuery[explorer.ListFileService](explorer.ListFileParameterCtx{}), + controllers.ListDirectory, + ) + // Create file + file.POST("create", + controllers.FromJSON[explorer.CreateFileService](explorer.CreateFileParameterCtx{}), + controllers.CreateFile, + ) + // Rename file + file.POST("rename", + controllers.FromJSON[explorer.RenameFileService](explorer.RenameFileParameterCtx{}), + controllers.RenameFile, ) + // Move or copy files + file.POST("move", + controllers.FromJSON[explorer.MoveFileService](explorer.MoveFileParameterCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.MoveFileParameterCtx{}), + controllers.MoveFile) + // Get URL of the file for preview/download + file.POST("url", + middleware.ContextHint(), + controllers.FromJSON[explorer.FileURLService](explorer.FileURLParameterCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.FileURLParameterCtx{}), + 
controllers.FileURL, + ) + // Update file content + file.PUT("content", + controllers.FromQuery[explorer.FileUpdateService](explorer.FileUpdateParameterCtx{}), + controllers.PutContent) + // Get entity content for preview/download + content := file.Group("content") + contentCors := cors.New(cors.Config{ + AllowOrigins: []string{"*"}, + }) + content.Use(contentCors) + { + content.OPTIONS("*option", contentCors) + content.GET(":id/:speed/:name", + middleware.SignRequired(dep.GeneralAuth()), + middleware.HashID(hashid.EntityID), + middleware.Sandbox(), + controllers.FromUri[explorer.EntityDownloadService](explorer.EntityDownloadParameterCtx{}), + controllers.ServeEntity, + ) + content.HEAD(":id/:speed/:name", + middleware.SignRequired(dep.GeneralAuth()), + middleware.HashID(hashid.EntityID), + controllers.FromUri[explorer.EntityDownloadService](explorer.EntityDownloadParameterCtx{}), + controllers.ServeEntity, + ) + } // 获取缩略图 - share.GET("thumb/:id/:file", - middleware.CheckShareUnlocked(), - middleware.ShareCanPreview(), - controllers.ShareThumb, + file.GET("thumb", + middleware.ContextHint(), + controllers.FromQuery[explorer.FileThumbService](explorer.FileThumbParameterCtx{}), + controllers.Thumb, + ) + // Delete files + file.DELETE("", + controllers.FromJSON[explorer.DeleteFileService](explorer.DeleteFileParameterCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.DeleteFileParameterCtx{}), + controllers.Delete, + ) + // Force unlock + file.DELETE("lock", + controllers.FromJSON[explorer.UnlockFileService](explorer.UnlockFileParameterCtx{}), + controllers.Unlock, + ) + // Restore files + file.POST("restore", + controllers.FromJSON[explorer.DeleteFileService](explorer.DeleteFileParameterCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.DeleteFileParameterCtx{}), + controllers.Restore, + ) + // Patch metadata + file.PATCH("metadata", + controllers.FromJSON[explorer.PatchMetadataService](explorer.PatchMetadataParameterCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.PatchMetadataParameterCtx{}), + controllers.PatchMetadata, + ) + // Upload related + upload := file.Group("upload") + { + // Create upload session + upload.PUT("", + controllers.FromJSON[explorer.CreateUploadSessionService](explorer.CreateUploadSessionParameterCtx{}), + controllers.CreateUploadSession, + ) + // Upload file data + upload.POST(":sessionId/:index", + controllers.FromUri[explorer.UploadService](explorer.UploadParameterCtx{}), + controllers.FileUpload, + ) + upload.DELETE("", + controllers.FromJSON[explorer.DeleteUploadSessionService](explorer.DeleteUploadSessionParameterCtx{}), + controllers.DeleteUploadSession, + ) + } + // Pin file + pin := file.Group("pin") + { + // Pin file + pin.PUT("", + controllers.FromJSON[explorer.PinFileService](explorer.PinFileParameterCtx{}), + controllers.Pin, + ) + // Unpin file + pin.DELETE("", + controllers.FromJSON[explorer.PinFileService](explorer.PinFileParameterCtx{}), + controllers.Unpin, + ) + } + // Get file info + file.GET("info", + controllers.FromQuery[explorer.GetFileInfoService](explorer.GetFileInfoParameterCtx{}), + controllers.GetFileInfo, + ) + // Version management + version := file.Group("version") + { + // Set current version + version.POST("current", + controllers.FromJSON[explorer.SetCurrentVersionService](explorer.SetCurrentVersionParamCtx{}), + controllers.SetCurrentVersion, + ) + // Delete a version from a file + version.DELETE("", + controllers.FromJSON[explorer.DeleteVersionService](explorer.DeleteVersionParamCtx{}), + 
controllers.DeleteVersion, + ) + } + file.PUT("viewerSession", + controllers.FromJSON[explorer.CreateViewerSessionService](explorer.CreateViewerSessionParamCtx{}), + controllers.CreateViewerSession, ) - // 搜索公共分享 - v3.Group("share").GET("search", controllers.SearchShare) + + // 取得文件外链 + file.PUT("source", + controllers.FromJSON[explorer.GetDirectLinkService](explorer.GetDirectLinkParamCtx{}), + middleware.ValidateBatchFileCount(dep, explorer.GetDirectLinkParamCtx{}), + controllers.GetSource) } - wopi := v3.Group( - "wopi", - middleware.HashID(hashid.FileID), - middleware.WopiAccessValidation(wopi2.Default, cache.Store), - ) + // 分享相关 + share := v4.Group("share") { - // 获取文件信息 - wopi.GET("files/:id", controllers.CheckFileInfo) - // 获取文件内容 - wopi.GET("files/:id/contents", controllers.GetFile) - // 更新文件内容 - wopi.POST("files/:id/contents", middleware.WopiWriteAccess(), controllers.PutFile) - // 通用文件操作 - wopi.POST("files/:id", middleware.WopiWriteAccess(), controllers.ModifyFile) + // Create share link + share.PUT("", + middleware.LoginRequired(), + controllers.FromJSON[sharesvc.ShareCreateService](sharesvc.ShareCreateParamCtx{}), + controllers.CreateShare, + ) + // Edit existing share link + share.POST(":id", + middleware.LoginRequired(), + middleware.HashID(hashid.ShareID), + controllers.FromJSON[sharesvc.ShareCreateService](sharesvc.ShareCreateParamCtx{}), + controllers.EditShare, + ) + // Get share link info + share.GET("info/:id", + middleware.HashID(hashid.ShareID), + controllers.FromQuery[sharesvc.ShareInfoService](sharesvc.ShareInfoParamCtx{}), + controllers.GetShare, + ) + // List my shares + share.GET("", + middleware.LoginRequired(), + controllers.FromQuery[sharesvc.ListShareService](sharesvc.ListShareParamCtx{}), + controllers.ListShare, + ) + // 删除分享 + share.DELETE(":id", + middleware.LoginRequired(), + middleware.HashID(hashid.ShareID), + controllers.DeleteShare, + ) + //// 获取README文本文件内容 + //share.GET("readme/:id", + // middleware.CheckShareUnlocked(), + // controllers.PreviewShareReadme, + //) + //// 举报分享 + //share.POST("report/:id", + // middleware.CheckShareUnlocked(), + // controllers.ReportShare, + //) } // 需要登录保护的 - auth := v3.Group("") - auth.Use(middleware.AuthRequired()) + auth := v4.Group("") + auth.Use(middleware.LoginRequired()) { // 管理 admin := auth.Group("admin", middleware.IsAdmin()) { - // 获取站点概况 - admin.GET("summary", controllers.AdminSummary) - // 获取社区新闻 - admin.GET("news", controllers.AdminNews) - // 更改设置 - admin.PATCH("setting", controllers.AdminChangeSetting) - // 获取设置 - admin.POST("setting", controllers.AdminGetSetting) - // 获取用户组列表 - admin.GET("groups", controllers.AdminGetGroups) - // 重新加载子服务 - admin.GET("reload/:service", controllers.AdminReloadService) - // 测试设置 - test := admin.Group("test") + admin.GET("summary", + controllers.FromQuery[adminsvc.SummaryService](adminsvc.SummaryParamCtx{}), + controllers.AdminSummary, + ) + + settings := admin.Group("settings") { - // 测试邮件设置 - test.POST("mail", controllers.AdminSendTestMail) - // 测试缩略图生成器调用 - test.POST("thumb", controllers.AdminTestThumbGenerator) + // Get settings + settings.POST("", + controllers.FromJSON[adminsvc.GetSettingService](adminsvc.GetSettingParamCtx{}), + controllers.AdminGetSettings, + ) + // Patch settings + settings.PATCH("", + controllers.FromJSON[adminsvc.SetSettingService](adminsvc.SetSettingParamCtx{}), + controllers.AdminSetSettings, + ) } - // 离线下载相关 - aria2 := admin.Group("aria2") + // 用户组管理 + group := admin.Group("group") { - // 测试连接配置 - aria2.POST("test", 
controllers.AdminTestAria2) + // 列出用户组 + group.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListGroups, + ) + // 获取用户组 + group.GET(":id", + controllers.FromUri[adminsvc.SingleGroupService](adminsvc.SingleGroupParamCtx{}), + controllers.AdminGetGroup, + ) + // 创建用户组 + group.PUT("", + controllers.FromJSON[adminsvc.UpsertGroupService](adminsvc.UpsertGroupParamCtx{}), + controllers.AdminCreateGroup, + ) + // 更新用户组 + group.PUT(":id", + controllers.FromJSON[adminsvc.UpsertGroupService](adminsvc.UpsertGroupParamCtx{}), + controllers.AdminUpdateGroup, + ) + // 删除用户组 + group.DELETE(":id", + controllers.FromUri[adminsvc.SingleGroupService](adminsvc.SingleGroupParamCtx{}), + controllers.AdminDeleteGroup, + ) + } + + tool := admin.Group("tool") + { + tool.GET("wopi", + controllers.FromQuery[adminsvc.FetchWOPIDiscoveryService](adminsvc.FetchWOPIDiscoveryParamCtx{}), + controllers.AdminFetchWopi, + ) + tool.POST("thumbExecutable", + controllers.FromJSON[adminsvc.ThumbGeneratorTestService](adminsvc.ThumbGeneratorTestParamCtx{}), + controllers.AdminTestThumbGenerator) + tool.POST("mail", + controllers.FromJSON[adminsvc.TestSMTPService](adminsvc.TestSMTPParamCtx{}), + controllers.AdminSendTestMail, + ) + tool.DELETE("entityUrlCache", + controllers.AdminClearEntityUrlCache, + ) + } + + queue := admin.Group("queue") + { + queue.GET("metrics", controllers.AdminGetQueueMetrics) + // List tasks + queue.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListTasks, + ) + // Get task + queue.GET(":id", + controllers.FromUri[adminsvc.SingleTaskService](adminsvc.SingleTaskParamCtx{}), + controllers.AdminGetTask, + ) + // Batch delete task + queue.POST("batch/delete", + controllers.FromJSON[adminsvc.BatchTaskService](adminsvc.BatchTaskParamCtx{}), + controllers.AdminBatchDeleteTask, + ) + // // 列出任务 + // queue.POST("list", controllers.AdminListTask) + // // 新建文件导入任务 + // queue.POST("import", controllers.AdminCreateImportTask) } // 存储策略管理 policy := admin.Group("policy") { // 列出存储策略 - policy.POST("list", controllers.AdminListPolicy) - // 测试本地路径可用性 - policy.POST("test/path", controllers.AdminTestPath) - // 测试从机通信 - policy.POST("test/slave", controllers.AdminTestSlave) + policy.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListPolicies, + ) + // 获取存储策略详情 + policy.GET(":id", + controllers.FromUri[adminsvc.SingleStoragePolicyService](adminsvc.GetStoragePolicyParamCtx{}), + controllers.AdminGetPolicy, + ) // 创建存储策略 - policy.POST("", controllers.AdminAddPolicy) + policy.PUT("", + controllers.FromJSON[adminsvc.CreateStoragePolicyService](adminsvc.CreateStoragePolicyParamCtx{}), + controllers.AdminCreatePolicy, + ) + // 更新存储策略 + policy.PUT(":id", + controllers.FromJSON[adminsvc.UpdateStoragePolicyService](adminsvc.UpdateStoragePolicyParamCtx{}), + controllers.AdminUpdatePolicy, + ) // 创建跨域策略 - policy.POST("cors", controllers.AdminAddCORS) - // 创建COS回调函数 - policy.POST("scf", controllers.AdminAddSCF) - // 获取 OneDrive OAuth URL - oauth := policy.Group(":id/oauth") + policy.POST("cors", + controllers.FromJSON[adminsvc.CreateStoragePolicyCorsService](adminsvc.CreateStoragePolicyCorsParamCtx{}), + controllers.AdminCreateStoragePolicyCors, + ) + // // 获取 OneDrive OAuth URL + oauth := policy.Group("oauth") { // 获取 OneDrive OAuth URL - oauth.GET("onedrive", controllers.AdminOAuthURL("onedrive")) - // 获取 Google Drive OAuth URL - 
oauth.GET("googledrive", controllers.AdminOAuthURL("googledrive")) + oauth.POST("signin", + controllers.FromJSON[adminsvc.GetOauthRedirectService](adminsvc.GetOauthRedirectParamCtx{}), + controllers.AdminOdOAuthURL, + ) + // 获取 OAuth 回调 URL + oauth.GET("redirect", controllers.AdminGetPolicyOAuthCallbackURL) + oauth.GET("status/:id", + controllers.FromUri[adminsvc.SingleStoragePolicyService](adminsvc.GetStoragePolicyParamCtx{}), + controllers.AdminGetPolicyOAuthStatus, + ) + oauth.POST("callback", + controllers.FromJSON[adminsvc.FinishOauthCallbackService](adminsvc.FinishOauthCallbackParamCtx{}), + controllers.AdminFinishOauthCallback, + ) + oauth.GET("root/:id", + controllers.FromUri[adminsvc.SingleStoragePolicyService](adminsvc.GetStoragePolicyParamCtx{}), + controllers.AdminGetSharePointDriverRoot, + ) } - // 获取 存储策略 - policy.GET(":id", controllers.AdminGetPolicy) + // // 获取 存储策略 + // policy.GET(":id", controllers.AdminGetPolicy) // 删除 存储策略 - policy.DELETE(":id", controllers.AdminDeletePolicy) + policy.DELETE(":id", + controllers.FromUri[adminsvc.SingleStoragePolicyService](adminsvc.GetStoragePolicyParamCtx{}), + controllers.AdminDeletePolicy, + ) } - // 用户组管理 - group := admin.Group("group") + node := admin.Group("node") { - // 列出用户组 - group.POST("list", controllers.AdminListGroup) - // 获取用户组 - group.GET(":id", controllers.AdminGetGroup) - // 创建/保存用户组 - group.POST("", controllers.AdminAddGroup) - // 删除 - group.DELETE(":id", controllers.AdminDeleteGroup) + node.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListNodes, + ) + node.GET(":id", + controllers.FromUri[adminsvc.SingleNodeService](adminsvc.SingleNodeParamCtx{}), + controllers.AdminGetNode, + ) + node.POST("test", + controllers.FromJSON[adminsvc.TestNodeService](adminsvc.TestNodeParamCtx{}), + controllers.AdminTestSlave, + ) + node.POST("test/downloader", + controllers.FromJSON[adminsvc.TestNodeDownloaderService](adminsvc.TestNodeDownloaderParamCtx{}), + controllers.AdminTestDownloader, + ) + node.PUT("", + controllers.FromJSON[adminsvc.UpsertNodeService](adminsvc.UpsertNodeParamCtx{}), + controllers.AdminCreateNode, + ) + node.PUT(":id", + controllers.FromJSON[adminsvc.UpsertNodeService](adminsvc.UpsertNodeParamCtx{}), + controllers.AdminUpdateNode, + ) + node.DELETE(":id", + controllers.FromUri[adminsvc.SingleNodeService](adminsvc.SingleNodeParamCtx{}), + controllers.AdminDeleteNode, + ) } user := admin.Group("user") { // 列出用户 - user.POST("list", controllers.AdminListUser) + user.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListUsers, + ) // 获取用户 - user.GET(":id", controllers.AdminGetUser) - // 创建/保存用户 - user.POST("", controllers.AdminAddUser) - // 删除 - user.POST("delete", controllers.AdminDeleteUser) - // 封禁/解封用户 - user.PATCH("ban/:id", controllers.AdminBanUser) + user.GET(":id", + controllers.FromUri[adminsvc.SingleUserService](adminsvc.SingleUserParamCtx{}), + controllers.AdminGetUser, + ) + // 更新用户 + user.PUT(":id", + controllers.FromJSON[adminsvc.UpsertUserService](adminsvc.UpsertUserParamCtx{}), + controllers.AdminUpdateUser, + ) + // 创建用户 + user.PUT("", + controllers.FromJSON[adminsvc.UpsertUserService](adminsvc.UpsertUserParamCtx{}), + controllers.AdminCreateUser, + ) + batch := user.Group("batch") + { + // 批量删除用户 + batch.POST("delete", + controllers.FromJSON[adminsvc.BatchUserService](adminsvc.BatchUserParamCtx{}), + controllers.AdminDeleteUser, + ) + } + user.POST(":id/calibrate", 
+ controllers.FromUri[adminsvc.SingleUserService](adminsvc.SingleUserParamCtx{}), + controllers.AdminCalibrateStorage, + ) } file := admin.Group("file") { // 列出文件 - file.POST("list", controllers.AdminListFile) - // 预览文件 - file.GET("preview/:id", middleware.Sandbox(), controllers.AdminGetFile) - // 删除 - file.POST("delete", controllers.AdminDeleteFile) - // 列出用户或外部文件系统目录 - file.GET("folders/:type/:id/*path", - controllers.AdminListFolders) + file.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListFiles, + ) + // 获取文件 + file.GET(":id", + controllers.FromUri[adminsvc.SingleFileService](adminsvc.SingleFileParamCtx{}), + controllers.AdminGetFile, + ) + // 更新文件 + file.PUT(":id", + controllers.FromJSON[adminsvc.UpsertFileService](adminsvc.UpsertFileParamCtx{}), + controllers.AdminUpdateFile, + ) + // 获取文件 URL + file.GET("url/:id", + controllers.FromUri[adminsvc.SingleFileService](adminsvc.SingleFileParamCtx{}), + controllers.AdminGetFileUrl, + ) + // 批量删除文件 + file.POST("batch/delete", + controllers.FromJSON[adminsvc.BatchFileService](adminsvc.BatchFileParamCtx{}), + controllers.AdminBatchDeleteFile, + ) } - share := admin.Group("share") - { - // 列出分享 - share.POST("list", controllers.AdminListShare) - // 删除 - share.POST("delete", controllers.AdminDeleteShare) - } - - download := admin.Group("download") + entity := admin.Group("entity") { - // 列出任务 - download.POST("list", controllers.AdminListDownload) - // 删除 - download.POST("delete", controllers.AdminDeleteDownload) + // List blobs + entity.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListEntities, + ) + // Get entity + entity.GET(":id", + controllers.FromUri[adminsvc.SingleEntityService](adminsvc.SingleEntityParamCtx{}), + controllers.AdminGetEntity, + ) + // Batch delete entity + entity.POST("batch/delete", + controllers.FromJSON[adminsvc.BatchEntityService](adminsvc.BatchEntityParamCtx{}), + controllers.AdminBatchDeleteEntity, + ) + // Get entity url + entity.GET("url/:id", + controllers.FromUri[adminsvc.SingleEntityService](adminsvc.SingleEntityParamCtx{}), + controllers.AdminGetEntityUrl, + ) } - task := admin.Group("task") - { - // 列出任务 - task.POST("list", controllers.AdminListTask) - // 删除 - task.POST("delete", controllers.AdminDeleteTask) - // 新建文件导入任务 - task.POST("import", controllers.AdminCreateImportTask) - } - - node := admin.Group("node") + share := admin.Group("share") { - // 列出从机节点 - node.POST("list", controllers.AdminListNodes) - // 列出从机节点 - node.POST("aria2/test", controllers.AdminTestAria2) - // 创建/保存节点 - node.POST("", controllers.AdminAddNode) - // 启用/暂停节点 - node.PATCH("enable/:id/:desired", controllers.AdminToggleNode) - // 删除节点 - node.DELETE(":id", controllers.AdminDeleteNode) - // 获取节点 - node.GET(":id", controllers.AdminGetNode) + // List shares + share.POST("", + controllers.FromJSON[adminsvc.AdminListService](adminsvc.AdminListServiceParamsCtx{}), + controllers.AdminListShares, + ) + // Get share + share.GET(":id", + controllers.FromUri[adminsvc.SingleShareService](adminsvc.SingleShareParamCtx{}), + controllers.AdminGetShare, + ) + // Batch delete shares + share.POST("batch/delete", + controllers.FromJSON[adminsvc.BatchShareService](adminsvc.BatchShareParamCtx{}), + controllers.AdminBatchDeleteShare, + ) } - } // 用户 @@ -567,162 +1043,93 @@ func InitMasterRouter() *gin.Engine { // 当前登录用户信息 user.GET("me", controllers.UserMe) // 存储信息 - user.GET("storage", controllers.UserStorage) + 
user.GET("capacity", controllers.UserStorage) + // Search user by keywords + user.GET("search", + controllers.FromQuery[usersvc.SearchUserService](usersvc.SearchUserParamCtx{}), + controllers.UserSearch, + ) // 退出登录 user.DELETE("session", controllers.UserSignOut) - // Generate temp URL for copying client-side session, used in adding accounts - // for mobile App. - user.GET("session", controllers.UserPrepareCopySession) // WebAuthn 注册相关 authn := user.Group("authn", - middleware.IsFunctionEnabled("authn_enabled")) + middleware.IsFunctionEnabled(func(c *gin.Context) bool { + return dep.SettingProvider().AuthnEnabled(c) + })) { authn.PUT("", controllers.StartRegAuthn) - authn.PUT("finish", controllers.FinishRegAuthn) + authn.POST("", + controllers.FromJSON[usersvc.FinishPasskeyRegisterService](usersvc.FinishPasskeyRegisterParameterCtx{}), + controllers.FinishRegAuthn, + ) + authn.DELETE("", + controllers.FromQuery[usersvc.DeletePasskeyService](usersvc.DeletePasskeyParameterCtx{}), + controllers.UserDeletePasskey, + ) } // 用户设置 setting := user.Group("setting") { - // 任务队列 - setting.GET("tasks", controllers.UserTasks) // 获取当前用户设定 setting.GET("", controllers.UserSetting) // 从文件上传头像 - setting.POST("avatar", controllers.UploadAvatar) - // 设定为Gravatar头像 - setting.PUT("avatar", controllers.UseGravatar) + setting.PUT("avatar", controllers.UploadAvatar) // 更改用户设定 - setting.PATCH(":option", controllers.UpdateOption) + setting.PATCH("", + controllers.FromJSON[usersvc.PatchUserSetting](usersvc.PatchUserSettingParamsCtx{}), + controllers.UpdateOption, + ) // 获得二步验证初始化信息 setting.GET("2fa", controllers.UserInit2FA) } } - // 文件 - file := auth.Group("file", middleware.HashID(hashid.FileID)) + group := auth.Group("group") { - // 上传 - upload := file.Group("upload") - { - // 文件上传 - upload.POST(":sessionId/:index", controllers.FileUpload) - // 创建上传会话 - upload.PUT("", controllers.GetUploadSession) - // 删除给定上传会话 - upload.DELETE(":sessionId", controllers.DeleteUploadSession) - // 删除全部上传会话 - upload.DELETE("", controllers.DeleteAllUploadSession) - } - // 更新文件 - file.PUT("update/:id", controllers.PutContent) - // 创建空白文件 - file.POST("create", controllers.CreateFile) - // 创建文件下载会话 - file.PUT("download/:id", controllers.CreateDownloadSession) - // 预览文件 - file.GET("preview/:id", middleware.Sandbox(), controllers.Preview) - // 获取文本文件内容 - file.GET("content/:id", middleware.Sandbox(), controllers.PreviewText) - // 取得Office文档预览地址 - file.GET("doc/:id", controllers.GetDocPreview) - // 获取缩略图 - file.GET("thumb/:id", controllers.Thumb) - // 取得文件外链 - file.POST("source", controllers.GetSource) - // 打包要下载的文件 - file.POST("archive", controllers.Archive) - // 创建文件压缩任务 - file.POST("compress", controllers.Compress) - // 创建文件解压缩任务 - file.POST("decompress", controllers.Decompress) - // 创建文件解压缩任务 - file.GET("search/:type/:keywords", controllers.SearchFile) - } - - // 离线下载任务 - aria2 := auth.Group("aria2") - { - // 创建URL下载任务 - aria2.POST("url", controllers.AddAria2URL) - // 创建种子下载任务 - aria2.POST("torrent/:id", middleware.HashID(hashid.FileID), controllers.AddAria2Torrent) - // 重新选择要下载的文件 - aria2.PUT("select/:gid", controllers.SelectAria2File) - // 取消或删除下载任务 - aria2.DELETE("task/:gid", controllers.CancelAria2Download) - // 获取正在下载中的任务 - aria2.GET("downloading", controllers.ListDownloading) - // 获取已完成的任务 - aria2.GET("finished", controllers.ListFinished) - } - - // 目录 - directory := auth.Group("directory") - { - // 创建目录 - directory.PUT("", controllers.CreateDirectory) - // 列出目录下内容 - directory.GET("*path", controllers.ListDirectory) - } - - // 
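// Feature and captcha gates in the v4 routes above take per-request predicates
// (e.g. middleware.IsFunctionEnabled(func(c *gin.Context) bool { return
// dep.SettingProvider().AuthnEnabled(c) })) instead of the v3 setting-name
// strings. A rough sketch of such a gate; the package name and the rejection
// payload are assumptions rather than the project's actual implementation.
package middleware

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

// IsFunctionEnabled aborts the request when the supplied check reports the
// feature as disabled for this request.
func IsFunctionEnabled(check func(c *gin.Context) bool) gin.HandlerFunc {
	return func(c *gin.Context) {
		if !check(c) {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"msg": "This feature is not enabled"})
			return
		}
		c.Next()
	}
}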
对象,文件和目录的抽象 - object := auth.Group("object") - { - // 删除对象 - object.DELETE("", controllers.Delete) - // 移动对象 - object.PATCH("", controllers.Move) - // 复制对象 - object.POST("copy", controllers.Copy) - // 重命名对象 - object.POST("rename", controllers.Rename) - // 获取对象属性 - object.GET("property/:id", controllers.GetProperty) + // list all groups for options + group.GET("list", controllers.GetGroupList) } - // 分享 - share := auth.Group("share") + // WebDAV and devices + devices := auth.Group("devices") { - // 创建新分享 - share.POST("", controllers.CreateShare) - // 列出我的分享 - share.GET("", controllers.ListShare) - // 更新分享属性 - share.PATCH(":id", - middleware.ShareAvailable(), - middleware.ShareOwner(), - controllers.UpdateShare, - ) - // 删除分享 - share.DELETE(":id", - controllers.DeleteShare, - ) - } - - // 用户标签 - tag := auth.Group("tag") - { - // 创建文件分类标签 - tag.POST("filter", controllers.CreateFilterTag) - // 创建目录快捷方式标签 - tag.POST("link", controllers.CreateLinkTag) - // 删除标签 - tag.DELETE(":id", middleware.HashID(hashid.TagID), controllers.DeleteTag) - } - - // WebDAV管理相关 - webdav := auth.Group("webdav") - { - // 获取账号信息 - webdav.GET("accounts", controllers.GetWebDAVAccounts) - // 新建账号 - webdav.POST("accounts", controllers.CreateWebDAVAccounts) - // 删除账号 - webdav.DELETE("accounts/:id", controllers.DeleteWebDAVAccounts) - // 更新账号可读性和是否使用代理服务 - webdav.PATCH("accounts", controllers.UpdateWebDAVAccounts) + dav := devices.Group("dav") + { + // List WebDAV accounts + dav.GET("", + controllers.FromQuery[setting.ListDavAccountsService](setting.ListDavAccountParamCtx{}), + controllers.ListDavAccounts, + ) + // Create WebDAV account + dav.PUT("", + controllers.FromJSON[setting.CreateDavAccountService](setting.CreateDavAccountParamCtx{}), + controllers.CreateDAVAccounts, + ) + // Create WebDAV account + dav.PATCH(":id", + middleware.HashID(hashid.DavAccountID), + controllers.FromJSON[setting.CreateDavAccountService](setting.CreateDavAccountParamCtx{}), + controllers.UpdateDAVAccounts, + ) + // Delete WebDAV account + dav.DELETE(":id", + middleware.HashID(hashid.DavAccountID), + controllers.DeleteDAVAccounts, + ) + } + //// 获取账号信息 + //devices.GET("dav", controllers.GetWebDAVAccounts) + //// 删除目录挂载 + //devices.DELETE("mount/:id", + // middleware.HashID(hashid.FolderID), + // controllers.DeleteWebDAVMounts, + //) + //// 创建目录挂载 + //devices.POST("mount", controllers.CreateWebDAVMounts) + //// 更新账号可读性 + //devices.PATCH("accounts", controllers.UpdateWebDAVAccountsReadonly) } } @@ -737,18 +1144,17 @@ func InitMasterRouter() *gin.Engine { // initWebDAV 初始化WebDAV相关路由 func initWebDAV(group *gin.RouterGroup) { { - group.Use(middleware.WebDAVAuth()) - - group.Any("/*path", controllers.ServeWebDAV) - group.Any("", controllers.ServeWebDAV) - group.Handle("PROPFIND", "/*path", controllers.ServeWebDAV) - group.Handle("PROPFIND", "", controllers.ServeWebDAV) - group.Handle("MKCOL", "/*path", controllers.ServeWebDAV) - group.Handle("LOCK", "/*path", controllers.ServeWebDAV) - group.Handle("UNLOCK", "/*path", controllers.ServeWebDAV) - group.Handle("PROPPATCH", "/*path", controllers.ServeWebDAV) - group.Handle("COPY", "/*path", controllers.ServeWebDAV) - group.Handle("MOVE", "/*path", controllers.ServeWebDAV) + group.Use(middleware.CacheControl(), middleware.WebDAVAuth()) + group.Any("/*path", webdav.ServeHTTP) + group.Any("", webdav.ServeHTTP) + group.Handle("PROPFIND", "/*path", webdav.ServeHTTP) + group.Handle("PROPFIND", "", webdav.ServeHTTP) + group.Handle("MKCOL", "/*path", webdav.ServeHTTP) + group.Handle("LOCK", "/*path", 
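// gin's group.Any only registers the standard HTTP verbs, so the WebDAV-specific
// methods (PROPFIND, MKCOL, LOCK, UNLOCK, PROPPATCH, COPY, MOVE) are routed with
// explicit Handle calls as done in initWebDAV here. An equivalent compact form,
// assuming the same group and webdav.ServeHTTP handler, could look like this sketch:
for _, method := range []string{"PROPFIND", "MKCOL", "LOCK", "UNLOCK", "PROPPATCH", "COPY", "MOVE"} {
	group.Handle(method, "/*path", webdav.ServeHTTP)
}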
webdav.ServeHTTP) + group.Handle("UNLOCK", "/*path", webdav.ServeHTTP) + group.Handle("PROPPATCH", "/*path", webdav.ServeHTTP) + group.Handle("COPY", "/*path", webdav.ServeHTTP) + group.Handle("MOVE", "/*path", webdav.ServeHTTP) } } diff --git a/routers/router_test.go b/routers/router_test.go deleted file mode 100644 index 2476de6a..00000000 --- a/routers/router_test.go +++ /dev/null @@ -1,251 +0,0 @@ -package routers - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "net/http" - "net/http/httptest" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestPing(t *testing.T) { - asserts := assert.New(t) - router := InitMasterRouter() - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/api/v3/site/ping", nil) - router.ServeHTTP(w, req) - - assert.Equal(t, 200, w.Code) - asserts.Contains(w.Body.String(), conf.BackendVersion) -} - -func TestCaptcha(t *testing.T) { - asserts := assert.New(t) - router := InitMasterRouter() - w := httptest.NewRecorder() - - req, _ := http.NewRequest( - "GET", - "/api/v3/site/captcha", - nil, - ) - - router.ServeHTTP(w, req) - - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "base64") -} - -//func TestUserSession(t *testing.T) { -// mutex.Lock() -// defer mutex.Unlock() -// switchToMockDB() -// asserts := assert.New(t) -// router := InitMasterRouter() -// w := httptest.NewRecorder() -// -// // 创建测试用验证码 -// var configD = base64Captcha.ConfigDigit{ -// Height: 80, -// Width: 240, -// MaxSkew: 0.7, -// DotCount: 80, -// CaptchaLen: 1, -// } -// idKeyD, _ := base64Captcha.GenerateCaptcha("", configD) -// middleware.ContextMock = map[string]interface{}{ -// "captchaID": idKeyD, -// } -// -// testCases := []struct { -// settingRows *sqlmock.Rows -// userRows *sqlmock.Rows -// policyRows *sqlmock.Rows -// reqBody string -// expected interface{} -// }{ -// // 登录信息正确,不需要验证码 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.BuildUserResponse(model.User{ -// Email: "admin@cloudreve.org", -// Nick: "admin", -// Policy: model.Policy{ -// Type: "local", -// OptionsSerialized: model.PolicyOption{FileType: []string{}}, -// }, -// }), -// }, -// // 登录信息正确,需要验证码,验证码错误 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "1", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.ParamErr("验证码错误", nil), -// }, -// // 邮箱正确密码错误 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). 
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.Err(401, "用户邮箱或密码错误", nil), -// }, -// //邮箱格式不正确 -// { -// reqBody: `{"userName":"admin@cloudreve","captchaCode":"captchaCode","Password":"admin123"}`, -// expected: serializer.Err(40001, "邮箱格式不正确", errors.New("Key: 'UserLoginService.UserName' Error:Field validation for 'UserName' failed on the 'email' tag")), -// }, -// // 用户被Ban -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.Baned), -// expected: serializer.Err(403, "该账号已被封禁", nil), -// }, -// // 用户未激活 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.NotActivicated), -// expected: serializer.Err(403, "该账号未激活", nil), -// }, -// } -// -// for k, testCase := range testCases { -// if testCase.settingRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.settingRows) -// } -// if testCase.userRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows) -// } -// if testCase.policyRows != nil { -// mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND \\(\\(`policies`.`id` = 1\\)\\)(.+)$").WillReturnRows(testCase.policyRows) -// } -// req, _ := http.NewRequest( -// "POST", -// "/api/v3/user/session", -// bytes.NewReader([]byte(testCase.reqBody)), -// ) -// router.ServeHTTP(w, req) -// -// asserts.Equal(200, w.Code) -// expectedJSON, _ := json.Marshal(testCase.expected) -// asserts.JSONEq(string(expectedJSON), w.Body.String(), "测试用例:%d", k) -// -// w.Body.Reset() -// asserts.NoError(mock.ExpectationsWereMet()) -// model.ClearCache() -// } -// -//} -// -//func TestSessionAuthCheck(t *testing.T) { -// mutex.Lock() -// defer mutex.Unlock() -// switchToMockDB() -// asserts := assert.New(t) -// router := InitMasterRouter() -// w := httptest.NewRecorder() -// -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}")) -// expectedUser, _ := model.GetUserByID(1) -// -// testCases := []struct { -// userRows *sqlmock.Rows -// sessionMock map[string]interface{} -// contextMock map[string]interface{} -// expected interface{} -// }{ -// // 未登录 -// { -// expected: serializer.CheckLogin(), -// }, -// // 登录正常 -// { -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). 
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// sessionMock: map[string]interface{}{"user_id": 1}, -// expected: serializer.BuildUserResponse(expectedUser), -// }, -// // UID不存在 -// { -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}), -// sessionMock: map[string]interface{}{"user_id": -1}, -// expected: serializer.CheckLogin(), -// }, -// } -// -// for _, testCase := range testCases { -// req, _ := http.NewRequest( -// "GET", -// "/api/v3/user/me", -// nil, -// ) -// if testCase.userRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows) -// } -// middleware.ContextMock = testCase.contextMock -// middleware.SessionMock = testCase.sessionMock -// router.ServeHTTP(w, req) -// expectedJSON, _ := json.Marshal(testCase.expected) -// -// asserts.Equal(200, w.Code) -// asserts.JSONEq(string(expectedJSON), w.Body.String()) -// asserts.NoError(mock.ExpectationsWereMet()) -// -// w.Body.Reset() -// } -// -//} - -func TestSiteConfigRoute(t *testing.T) { - switchToMemDB() - asserts := assert.New(t) - router := InitMasterRouter() - w := httptest.NewRecorder() - - req, _ := http.NewRequest( - "GET", - "/api/v3/site/config", - nil, - ) - router.ServeHTTP(w, req) - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "Cloudreve") - - w.Body.Reset() - - // 消除无效值 - model.DB.Model(&model.Setting{ - Model: gorm.Model{ - ID: 2, - }, - }).UpdateColumn("name", "siteName_b") - - req, _ = http.NewRequest( - "GET", - "/api/v3/site/config", - nil, - ) - router.ServeHTTP(w, req) - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "\"title\"") - - model.DB.Model(&model.Setting{ - Model: gorm.Model{ - ID: 2, - }, - }).UpdateColumn("name", "siteName") -} diff --git a/service/admin/aria2.go b/service/admin/aria2.go deleted file mode 100644 index 6a2b77de..00000000 --- a/service/admin/aria2.go +++ /dev/null @@ -1,71 +0,0 @@ -package admin - -import ( - "bytes" - "encoding/json" - model "github.com/cloudreve/Cloudreve/v3/models" - "net/url" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" -) - -// Aria2TestService aria2连接测试服务 -type Aria2TestService struct { - Server string `json:"server"` - RPC string `json:"rpc" binding:"required"` - Secret string `json:"secret"` - Token string `json:"token"` - Type model.ModelType `json:"type"` -} - -// Test 测试aria2连接 -func (service *Aria2TestService) TestMaster() serializer.Response { - res, err := aria2.TestRPCConnection(service.RPC, service.Token, 5) - if err != nil { - return serializer.ParamErr("Failed to connect to RPC server: "+err.Error(), err) - } - - if res.Version == "" { - return serializer.ParamErr("RPC server returns unexpected response", nil) - } - - return serializer.Response{Data: res.Version} -} - -func (service *Aria2TestService) TestSlave() serializer.Response { - slave, err := url.Parse(service.Server) - if err != nil { - return serializer.ParamErr("Cannot parse slave server URL, "+err.Error(), nil) - } - - controller, _ := url.Parse("/api/v3/slave/ping/aria2") - - // 请求正文 - service.Type = model.MasterNodeType - bodyByte, _ := json.Marshal(service) - - r := request.NewClient() - res, err := r.Request( - "POST", - slave.ResolveReference(controller).String(), - bytes.NewReader(bodyByte), - request.WithTimeout(time.Duration(10)*time.Second), - request.WithCredential( 
- auth.HMACAuth{SecretKey: []byte(service.Secret)}, - int64(model.GetIntSetting("slave_api_timeout", 60)), - ), - ).DecodeResponse() - if err != nil { - return serializer.ParamErr("Failed to connect to slave node, "+err.Error(), nil) - } - - if res.Code != 0 { - return serializer.ParamErr("Successfully connected to slave, but slave returns: "+res.Msg, nil) - } - - return serializer.Response{Data: res.Data.(string)} -} diff --git a/service/admin/file.go b/service/admin/file.go index a029989d..de08a7e5 100644 --- a/service/admin/file.go +++ b/service/admin/file.go @@ -2,14 +2,22 @@ package admin import ( "context" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/service/explorer" + "path" + "strconv" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) // FileService 文件ID服务 @@ -33,176 +41,495 @@ type ListFolderService struct { // List 列出指定路径下的目录 func (service *ListFolderService) List(c *gin.Context) serializer.Response { - if service.Type == "policy" { - // 列取存储策略中的目录 - policy, err := model.GetPolicyByID(service.ID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) - } + //if service.Type == "policy" { + // // 列取存储策略中的目录 + // policy, err := model.GetPolicyByID(service.ID) + // if err != nil { + // return serializer.ErrDeprecated(serializer.CodePolicyNotExist, "", err) + // } + // + // // 创建文件系统 + // fs, err := filesystem.NewAnonymousFileSystem() + // if err != nil { + // return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) + // } + // defer fs.Recycle() + // + // // 列取存储策略中的文件 + // fs.Policy = &policy + // res, err := fs.ListPhysical(c.Request.Context(), service.Path) + // if err != nil { + // return serializer.ErrDeprecated(serializer.CodeListFilesError, "", err) + // } + // + // return serializer.Response{ + // Data: serializer.BuildObjectList(0, res, nil), + // } + // + //} + // + //// 列取用户空间目录 + //// 查找用户 + //user, err := model.GetUserByID(service.ID) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeUserNotFound, "", err) + //} + // + //// 创建文件系统 + //fs, err := filesystem.NewFileSystem(&user) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) + //} + //defer fs.Recycle() + // + //// 列取目录 + //res, err := fs.List(c.Request.Context(), service.Path, nil) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeListFilesError, "", err) + //} + + //return serializer.Response{ + // Data: serializer.BuildObjectList(0, res, nil), + //} + + return serializer.Response{} +} + +// Delete 删除文件 +func (service *FileBatchService) Delete(c *gin.Context) serializer.Response { + //files, err := model.GetFilesByIDs(service.ID, 0) + //if err != nil { + // return serializer.DBErrDeprecated("Failed to list 
files for deleting", err) + //} + // + //// 根据用户分组 + //userFile := make(map[uint][]model.File) + //for i := 0; i < len(files); i++ { + // if _, ok := userFile[files[i].UserID]; !ok { + // userFile[files[i].UserID] = []model.File{} + // } + // userFile[files[i].UserID] = append(userFile[files[i].UserID], files[i]) + //} + // + //// 异步执行删除 + //go func(files map[uint][]model.File) { + // for uid, file := range files { + // var ( + // fs *filesystem.FileSystem + // err error + // ) + // user, err := model.GetUserByID(uid) + // if err != nil { + // fs, err = filesystem.NewAnonymousFileSystem() + // if err != nil { + // continue + // } + // } else { + // fs, err = filesystem.NewFileSystem(&user) + // if err != nil { + // fs.Recycle() + // continue + // } + // } + // + // // 汇总文件ID + // ids := make([]uint, 0, len(file)) + // for i := 0; i < len(file); i++ { + // ids = append(ids, file[i].ID) + // } + // + // // 执行删除 + // fs.Delete(context.Background(), []uint{}, ids, service.Force, service.UnlinkOnly) + // fs.Recycle() + // } + //}(userFile) + + // 分组执行删除 + return serializer.Response{} + +} - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() +const ( + fileNameCondition = "file_name" + fileUserCondition = "file_user" + filePolicyCondition = "file_policy" +) + +func (service *AdminListService) Files(c *gin.Context) (*ListFileResponse, error) { + dep := dependency.FromContext(c) + hasher := dep.HashIDEncoder() + fileClient := dep.FileClient() + + ctx := context.WithValue(c, inventory.LoadFileEntity{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileMetadata{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileDirectLink{}, true) + + var ( + err error + userID int + policyID int + ) + + if service.Conditions[fileUserCondition] != "" { + userID, err = strconv.Atoi(service.Conditions[fileUserCondition]) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid user ID", err) } - defer fs.Recycle() + } - // 列取存储策略中的文件 - fs.Policy = &policy - res, err := fs.ListPhysical(c.Request.Context(), service.Path) + if service.Conditions[filePolicyCondition] != "" { + policyID, err = strconv.Atoi(service.Conditions[filePolicyCondition]) if err != nil { - return serializer.Err(serializer.CodeListFilesError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid policy ID", err) } + } - return serializer.Response{ - Data: serializer.BuildObjectList(0, res, nil), - } + res, err := fileClient.FlattenListFiles(ctx, &inventory.FlattenListFileParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: service.Page - 1, + PageSize: service.PageSize, + OrderBy: service.OrderBy, + Order: inventory.OrderDirection(service.OrderDirection), + }, + UserID: userID, + StoragePolicyID: policyID, + Name: service.Conditions[fileNameCondition], + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list files", err) + } + + return &ListFileResponse{ + Pagination: res.PaginationResults, + Files: lo.Map(res.Files, func(file *ent.File, _ int) GetFileResponse { + return GetFileResponse{ + File: file, + UserHashID: hashid.EncodeUserID(hasher, file.OwnerID), + } + }), + }, nil +} +type ( + SingleFileService struct { + ID int `uri:"id" json:"id" binding:"required"` } + SingleFileParamCtx struct{} +) + +func (service *SingleFileService) Get(c 
*gin.Context) (*GetFileResponse, error) { + dep := dependency.FromContext(c) + hasher := dep.HashIDEncoder() + fileClient := dep.FileClient() + + ctx := context.WithValue(c, inventory.LoadFileEntity{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileMetadata{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadEntityUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadEntityStoragePolicy{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileDirectLink{}, true) - // 列取用户空间目录 - // 查找用户 - user, err := model.GetUserByID(service.ID) + file, err := fileClient.GetByID(ctx, service.ID) if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "", err) + if ent.IsNotFound(err) { + return nil, serializer.NewError(serializer.CodeNotFound, "File not found", nil) + } + + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get file", err) + } + + directLinkMap := make(map[int]string) + siteURL := dep.SettingProvider().SiteURL(c) + for _, directLink := range file.Edges.DirectLinks { + directLinkMap[directLink.ID] = routes.MasterDirectLink(siteURL, hashid.EncodeSourceLinkID(hasher, directLink.ID), directLink.Name).String() + } + + return &GetFileResponse{ + File: file, + UserHashID: hashid.EncodeUserID(hasher, file.OwnerID), + DirectLinkMap: directLinkMap, + }, nil +} + +type ( + UpsertFileService struct { + File *ent.File `json:"file" binding:"required"` } + UpsertFileParamCtx struct{} +) - // 创建文件系统 - fs, err := filesystem.NewFileSystem(&user) +func (s *UpsertFileService) Update(c *gin.Context) (*GetFileResponse, error) { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + + fc, tx, ctx, err := inventory.WithTx(c, fileClient) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) } - defer fs.Recycle() - // 列取目录 - res, err := fs.List(c.Request.Context(), service.Path, nil) + newFile, err := fc.Update(ctx, s.File) if err != nil { - return serializer.Err(serializer.CodeListFilesError, "", err) + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update file", err) } - return serializer.Response{ - Data: serializer.BuildObjectList(0, res, nil), + if err := inventory.Commit(tx); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit transaction", err) } + + service := &SingleFileService{ID: newFile.ID} + return service.Get(c) } -// Delete 删除文件 -func (service *FileBatchService) Delete(c *gin.Context) serializer.Response { - files, err := model.GetFilesByIDs(service.ID, 0) +func (s *SingleFileService) Url(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + + ctx := context.WithValue(c, inventory.LoadFileEntity{}, true) + file, err := fileClient.GetByID(ctx, s.ID) if err != nil { - return serializer.DBErr("Failed to list files for deleting", err) + return "", serializer.NewError(serializer.CodeDBError, "Failed to get file", err) } - // 根据用户分组 - userFile := make(map[uint][]model.File) - for i := 0; i < len(files); i++ { - if _, ok := userFile[files[i].UserID]; !ok { - userFile[files[i].UserID] = []model.File{} + // find primary entity + var primaryEntity *ent.Entity + for _, entity := range file.Edges.Entities { + if entity.Type == int(types.EntityTypeVersion) && entity.ID == 
file.PrimaryEntity { + primaryEntity = entity + break } - userFile[files[i].UserID] = append(userFile[files[i].UserID], files[i]) - } - - // 异步执行删除 - go func(files map[uint][]model.File) { - for uid, file := range files { - var ( - fs *filesystem.FileSystem - err error - ) - user, err := model.GetUserByID(uid) - if err != nil { - fs, err = filesystem.NewAnonymousFileSystem() - if err != nil { - continue - } - } else { - fs, err = filesystem.NewFileSystem(&user) - if err != nil { - fs.Recycle() - continue - } - } + } - // 汇总文件ID - ids := make([]uint, 0, len(file)) - for i := 0; i < len(file); i++ { - ids = append(ids, file[i].ID) - } + if primaryEntity == nil { + return "", serializer.NewError(serializer.CodeNotFound, "Primary entity not exist", nil) + } - // 执行删除 - fs.Delete(context.Background(), []uint{}, ids, service.Force, service.UnlinkOnly) - fs.Recycle() - } - }(userFile) + // find policy + policy, err := dep.StoragePolicyClient().GetPolicyByID(ctx, primaryEntity.StoragePolicyEntities) + if err != nil { + return "", serializer.NewError(serializer.CodeDBError, "Failed to get policy", err) + } - // 分组执行删除 - return serializer.Response{} + m := manager.NewFileManager(dep, inventory.UserFromContext(c)) + defer m.Recycle() + driver, err := m.GetStorageDriver(ctx, policy) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to get storage driver", err) + } + + es := entitysource.NewEntitySource(fs.NewEntity(primaryEntity), driver, policy, dep.GeneralAuth(), + dep.SettingProvider(), dep.HashIDEncoder(), dep.RequestClient(), dep.Logger(), dep.ConfigProvider(), dep.MimeDetector(ctx)) + + expire := time.Now().Add(time.Hour * 1) + url, err := es.Url(ctx, entitysource.WithExpire(&expire), entitysource.WithDisplayName(file.Name)) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to get url", err) + } + + return url.Url, nil } -// Get 预览文件 -func (service *FileService) Get(c *gin.Context) serializer.Response { - file, err := model.GetFilesByIDs([]uint{service.ID}, 0) +type ( + BatchFileService struct { + IDs []int `json:"ids" binding:"min=1"` + } + BatchFileParamCtx struct{} +) + +func (s *BatchFileService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + + ctx := context.WithValue(c, inventory.LoadFileEntity{}, true) + files, _, err := fileClient.GetByIDs(ctx, s.IDs, 0) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to get files", err) + } + + fc, tx, ctx, err := inventory.WithTx(c, fileClient) if err != nil { - return serializer.Err(serializer.CodeFileNotFound, "", err) + return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err) } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, &file[0]) - var subService explorer.FileIDService - res := subService.PreviewContent(ctx, c, false) + _, diff, err := fc.Delete(ctx, files, nil) + if err != nil { + _ = inventory.Rollback(tx) + return serializer.NewError(serializer.CodeDBError, "Failed to delete files", err) + } - return res + tx.AppendStorageDiff(diff) + if err := inventory.CommitWithStorageDiff(ctx, tx, dep.Logger(), dep.UserClient()); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to commit transaction", err) + } + + return nil } -// Files 列出文件 -func (service *AdminListService) Files() serializer.Response { - var res []model.File - total := 0 +const ( + entityUserCondition = "entity_user" + entityPolicyCondition = 
"entity_policy" + entityTypeCondition = "entity_type" +) - tx := model.DB.Model(&model.File{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) +func (s *AdminListService) Entities(c *gin.Context) (*ListEntityResponse, error) { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + hasher := dep.HashIDEncoder() + ctx := context.WithValue(c, inventory.LoadEntityUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadEntityStoragePolicy{}, true) + + var ( + userID int + policyID int + err error + entityType *types.EntityType + ) + + if s.Conditions[entityUserCondition] != "" { + userID, err = strconv.Atoi(s.Conditions[entityUserCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid user ID", err) + } } - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) + if s.Conditions[entityPolicyCondition] != "" { + policyID, err = strconv.Atoi(s.Conditions[entityPolicyCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid policy ID", err) + } } - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += k + " like '%" + v + "%' OR " + if s.Conditions[entityTypeCondition] != "" { + typeId, err := strconv.Atoi(s.Conditions[entityTypeCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid entity type", err) } - search = strings.TrimSuffix(search, " OR ") - tx = tx.Where(search) + + t := types.EntityType(typeId) + entityType = &t } - // 计算总数用于分页 - tx.Count(&total) + res, err := fileClient.ListEntities(ctx, &inventory.ListEntityParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: s.Page - 1, + PageSize: s.PageSize, + OrderBy: s.OrderBy, + Order: inventory.OrderDirection(s.OrderDirection), + }, + UserID: userID, + StoragePolicyID: policyID, + EntityType: entityType, + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list entities", err) + } - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + return &ListEntityResponse{ + Pagination: res.PaginationResults, + Entities: lo.Map(res.Entities, func(entity *ent.Entity, _ int) GetEntityResponse { + return GetEntityResponse{ + Entity: entity, + UserHashID: hashid.EncodeUserID(hasher, entity.CreatedBy), + } + }), + }, nil +} - // 查询对应用户 - users := make(map[uint]model.User) - for _, file := range res { - users[file.UserID] = model.User{} +type ( + SingleEntityService struct { + ID int `uri:"id" json:"id" binding:"required"` } + SingleEntityParamCtx struct{} +) + +func (s *SingleEntityService) Get(c *gin.Context) (*GetEntityResponse, error) { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + hasher := dep.HashIDEncoder() + + ctx := context.WithValue(c, inventory.LoadEntityUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadEntityStoragePolicy{}, true) + ctx = context.WithValue(ctx, inventory.LoadEntityFile{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileUser{}, true) - userIDs := make([]uint, 0, len(users)) - for k := range users { - userIDs = append(userIDs, k) + userHashIDMap := make(map[int]string) + entity, err := fileClient.GetEntityByID(ctx, s.ID) + if err != nil { + if ent.IsNotFound(err) { + return nil, serializer.NewError(serializer.CodeNotFound, "Entity not found", nil) + } + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get entity", err) + } + + for _, file := range entity.Edges.File { + 
userHashIDMap[file.OwnerID] = hashid.EncodeUserID(hasher, file.OwnerID) + } + + return &GetEntityResponse{ + Entity: entity, + UserHashID: hashid.EncodeUserID(hasher, entity.CreatedBy), + UserHashIDMap: userHashIDMap, + }, nil +} + +type ( + BatchEntityService struct { + IDs []int `json:"ids" binding:"min=1"` + Force bool `json:"force"` } + BatchEntityParamCtx struct{} +) - var userList []model.User - model.DB.Where("id in (?)", userIDs).Find(&userList) +func (s *BatchEntityService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, inventory.UserFromContext(c)) + defer m.Recycle() + + err := m.RecycleEntities(c.Request.Context(), s.Force, s.IDs...) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to recycle entities", err) + } + + return nil +} - for _, v := range userList { - users[v.ID] = v +func (s *SingleEntityService) Url(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + fileClient := dep.FileClient() + + entity, err := fileClient.GetEntityByID(c, s.ID) + if err != nil { + return "", serializer.NewError(serializer.CodeDBError, "Failed to get file", err) + } + + // find policy + policy, err := dep.StoragePolicyClient().GetPolicyByID(c, entity.StoragePolicyEntities) + if err != nil { + return "", serializer.NewError(serializer.CodeDBError, "Failed to get policy", err) + } + + m := manager.NewFileManager(dep, inventory.UserFromContext(c)) + defer m.Recycle() + + driver, err := m.GetStorageDriver(c, policy) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to get storage driver", err) + } + + es := entitysource.NewEntitySource(fs.NewEntity(entity), driver, policy, dep.GeneralAuth(), + dep.SettingProvider(), dep.HashIDEncoder(), dep.RequestClient(), dep.Logger(), dep.ConfigProvider(), dep.MimeDetector(c)) + + expire := time.Now().Add(time.Hour * 1) + url, err := es.Url(c, entitysource.WithDownload(true), entitysource.WithExpire(&expire), entitysource.WithDisplayName(path.Base(entity.Source))) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to get url", err) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "users": users, - }} + return url.Url, nil } diff --git a/service/admin/group.go b/service/admin/group.go index 272ac992..e2cddf84 100644 --- a/service/admin/group.go +++ b/service/admin/group.go @@ -1,14 +1,20 @@ package admin import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "context" "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" ) // AddGroupService 用户组添加服务 type AddGroupService struct { - Group model.Group `json:"group" binding:"required"` + //Group model.Group `json:"group" binding:"required"` } // GroupService 用户组ID服务 @@ -18,100 +24,169 @@ type GroupService struct { // Get 获取用户组详情 func (service *GroupService) Get() serializer.Response { - group, err := model.GetGroupByID(service.ID) - if err != nil { - return serializer.Err(serializer.CodeGroupNotFound, "", err) - } + //group, err := model.GetGroupByID(service.ID) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeGroupNotFound, "", err) + //} + // + //return 
serializer.Response{Data: group} - return serializer.Response{Data: group} + return serializer.Response{} } // Delete 删除用户组 func (service *GroupService) Delete() serializer.Response { - // 查找用户组 - group, err := model.GetGroupByID(service.ID) + //// 查找用户组 + //group, err := model.GetGroupByID(service.ID) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeGroupNotFound, "", err) + //} + // + //// 是否为系统用户组 + //if group.ID <= 3 { + // return serializer.ErrDeprecated(serializer.CodeInvalidActionOnSystemGroup, "", err) + //} + // + //// 检查是否有用户使用 + //total := 0 + //row := model.DB.Model(&model.User{}).Where("group_id = ?", service.ID). + // Select("count(id)").Row() + //row.Scan(&total) + //if total > 0 { + // return serializer.ErrDeprecated(serializer.CodeGroupUsedByUser, strconv.Itoa(total), nil) + //} + // + //model.DB.Delete(&group) + + return serializer.Response{} +} + +func (service *SingleGroupService) Delete(c *gin.Context) error { + if service.ID <= 3 { + return serializer.NewError(serializer.CodeInvalidActionOnSystemGroup, "", nil) + } + + dep := dependency.FromContext(c) + groupClient := dep.GroupClient() + + // Any user still under this group? + users, err := groupClient.CountUsers(c, int(service.ID)) if err != nil { - return serializer.Err(serializer.CodeGroupNotFound, "", err) + return serializer.NewError(serializer.CodeDBError, "Failed to count users", err) } - // 是否为系统用户组 - if group.ID <= 3 { - return serializer.Err(serializer.CodeInvalidActionOnSystemGroup, "", err) + if users > 0 { + return serializer.NewError(serializer.CodeGroupUsedByUser, strconv.Itoa(users), nil) } - // 检查是否有用户使用 - total := 0 - row := model.DB.Model(&model.User{}).Where("group_id = ?", service.ID). - Select("count(id)").Row() - row.Scan(&total) - if total > 0 { - return serializer.Err(serializer.CodeGroupUsedByUser, strconv.Itoa(total), nil) + err = groupClient.Delete(c, service.ID) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete group", err) } - model.DB.Delete(&group) + return nil +} - return serializer.Response{} +func (s *AdminListService) List(c *gin.Context) (*ListGroupResponse, error) { + dep := dependency.FromContext(c) + groupClient := dep.GroupClient() + + ctx := context.WithValue(c, inventory.LoadGroupPolicy{}, true) + res, err := groupClient.ListGroups(ctx, &inventory.ListGroupParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: s.Page - 1, + PageSize: s.PageSize, + OrderBy: s.OrderBy, + Order: inventory.OrderDirection(s.OrderDirection), + }, + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list groups", err) + } + + return &ListGroupResponse{ + Pagination: res.PaginationResults, + Groups: res.Groups, + }, nil } -// Add 添加用户组 -func (service *AddGroupService) Add() serializer.Response { - if service.Group.ID > 0 { - if err := model.DB.Save(&service.Group).Error; err != nil { - return serializer.DBErr("Failed to save group record", err) - } - } else { - if err := model.DB.Create(&service.Group).Error; err != nil { - return serializer.DBErr("Failed to create group record", err) +type ( + SingleGroupService struct { + ID int `uri:"id" json:"id" binding:"required"` + } + SingleGroupParamCtx struct{} +) + +const ( + countUserQuery = "countUser" +) + +func (s *SingleGroupService) Get(c *gin.Context) (*GetGroupResponse, error) { + dep := dependency.FromContext(c) + groupClient := dep.GroupClient() + + ctx := context.WithValue(c, inventory.LoadGroupPolicy{}, true) + group, err := 
groupClient.GetByID(ctx, s.ID)
+	if err != nil {
+		return nil, serializer.NewError(serializer.CodeDBError, "Failed to get group", err)
+	}
+
+	res := &GetGroupResponse{Group: group}
+
+	if c.Query(countUserQuery) != "" {
+		totalUsers, err := groupClient.CountUsers(ctx, int(s.ID))
+		if err != nil {
+			return nil, serializer.NewError(serializer.CodeDBError, "Failed to count users", err)
+		}
+		res.TotalUsers = totalUsers
 	}
-	return serializer.Response{Data: service.Group.ID}
+	return res, nil
 }
-// Groups 列出用户组
-func (service *AdminListService) Groups() serializer.Response {
-	var res []model.Group
-	total := 0
+type (
+	UpsertGroupService struct {
+		Group *ent.Group `json:"group" binding:"required"`
+	}
+	UpsertGroupParamCtx struct{}
+)
+
+func (s *UpsertGroupService) Update(c *gin.Context) (*GetGroupResponse, error) {
+	dep := dependency.FromContext(c)
+	groupClient := dep.GroupClient()
-	tx := model.DB.Model(&model.Group{})
-	if service.OrderBy != "" {
-		tx = tx.Order(service.OrderBy)
+	if s.Group.ID == 0 {
+		return nil, serializer.NewError(serializer.CodeParamErr, "ID is required", nil)
 	}
-	for k, v := range service.Conditions {
-		tx = tx.Where(k+" = ?", v)
+	// Initial admin group has to be admin
+	if s.Group.ID == 1 && !s.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) {
+		return nil, serializer.NewError(serializer.CodeParamErr, "Initial admin group has to be admin", nil)
 	}
-	// 计算总数用于分页
-	tx.Count(&total)
+	group, err := groupClient.Upsert(c, s.Group)
+	if err != nil {
+		return nil, serializer.NewError(serializer.CodeDBError, "Failed to update group", err)
+	}
-	// 查询记录
-	tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
+	service := &SingleGroupService{ID: group.ID}
+	return service.Get(c)
+}
+
+func (s *UpsertGroupService) Create(c *gin.Context) (*GetGroupResponse, error) {
+	dep := dependency.FromContext(c)
+	groupClient := dep.GroupClient()
-	// 统计每个用户组的用户总数
-	statics := make(map[uint]int, len(res))
-	for i := 0; i < len(res); i++ {
-		total := 0
-		row := model.DB.Model(&model.User{}).Where("group_id = ?", res[i].ID).
- Select("count(id)").Row() - row.Scan(&total) - statics[res[i].ID] = total + if s.Group.ID > 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "ID must be 0", nil) } - // 汇总用户组存储策略 - policies := make(map[uint]model.Policy) - for i := 0; i < len(res); i++ { - for _, p := range res[i].PolicyList { - if _, ok := policies[p]; !ok { - policies[p], _ = model.GetPolicyByID(p) - } - } + group, err := groupClient.Upsert(c, s.Group) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create group", err) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "statics": statics, - "policies": policies, - }} + service := &SingleGroupService{ID: group.ID} + return service.Get(c) } diff --git a/service/admin/list.go b/service/admin/list.go index bd84e357..b9aa85c2 100644 --- a/service/admin/list.go +++ b/service/admin/list.go @@ -1,22 +1,26 @@ package admin import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" ) // AdminListService 仪表盘列条目服务 -type AdminListService struct { - Page int `json:"page" binding:"min=1,required"` - PageSize int `json:"page_size" binding:"min=1,required"` - OrderBy string `json:"order_by"` - Conditions map[string]string `form:"conditions"` - Searches map[string]string `form:"searches"` -} +type ( + AdminListService struct { + Page int `json:"page" binding:"min=1"` + PageSize int `json:"page_size" binding:"min=1,required"` + OrderBy string `json:"order_by"` + OrderDirection string `json:"order_direction"` + Conditions map[string]string `json:"conditions"` + Searches map[string]string `json:"searches"` + } + AdminListServiceParamsCtx struct{} +) // GroupList 获取用户组列表 func (service *NoParamService) GroupList() serializer.Response { - var res []model.Group - model.DB.Model(&model.Group{}).Find(&res) - return serializer.Response{Data: res} + //var res []model.Group + //model.DB.Model(&model.Group{}).Find(&res) + //return serializer.Response{Data: res} + return serializer.Response{} } diff --git a/service/admin/node.go b/service/admin/node.go index c8610581..1badadac 100644 --- a/service/admin/node.go +++ b/service/admin/node.go @@ -1,142 +1,256 @@ package admin import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" "strings" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/node" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/gin-gonic/gin" + "github.com/samber/lo" +) + +const ( + nodeStatusCondition = "node_status" ) -// AddNodeService 节点添加服务 -type AddNodeService struct { - Node model.Node `json:"node" binding:"required"` +func (service *AdminListService) Nodes(c *gin.Context) (*ListNodeResponse, error) { + dep := dependency.FromContext(c) + nodeClient := dep.NodeClient() + + ctx := 
context.WithValue(c, inventory.LoadNodeStoragePolicy{}, true) + res, err := nodeClient.ListNodes(ctx, &inventory.ListNodeParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: service.Page - 1, + PageSize: service.PageSize, + OrderBy: service.OrderBy, + Order: inventory.OrderDirection(service.OrderDirection), + }, + Status: node.Status(service.Conditions[nodeStatusCondition]), + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list nodes", err) + } + + return &ListNodeResponse{Nodes: res.Nodes, Pagination: res.PaginationResults}, nil } -// Add 添加节点 -func (service *AddNodeService) Add() serializer.Response { - if service.Node.ID > 0 { - if err := model.DB.Save(&service.Node).Error; err != nil { - return serializer.DBErr("Failed to save node record", err) - } - } else { - if err := model.DB.Create(&service.Node).Error; err != nil { - return serializer.DBErr("Failed to create node record", err) - } +type ( + SingleNodeService struct { + ID int `uri:"id" json:"id" binding:"required"` } + SingleNodeParamCtx struct{} +) - if service.Node.Status == model.NodeActive { - cluster.Default.Add(&service.Node) +func (service *SingleNodeService) Get(c *gin.Context) (*GetNodeResponse, error) { + dep := dependency.FromContext(c) + nodeClient := dep.NodeClient() + + ctx := context.WithValue(c, inventory.LoadNodeStoragePolicy{}, true) + node, err := nodeClient.GetNodeById(ctx, service.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get node", err) } - return serializer.Response{Data: service.Node.ID} + return &GetNodeResponse{Node: node}, nil } -// Nodes 列出从机节点 -func (service *AdminListService) Nodes() serializer.Response { - var res []model.Node - total := 0 - - tx := model.DB.Model(&model.Node{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) +type ( + TestNodeService struct { + Node *ent.Node `json:"node" binding:"required"` } + TestNodeParamCtx struct{} +) + +func (service *TestNodeService) Test(c *gin.Context) error { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) + slave, err := url.Parse(service.Node.Server) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "Failed to parse node URL", err) } - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += k + " like '%" + v + "%' OR " - } - search = strings.TrimSuffix(search, " OR ") - tx = tx.Where(search) + primaryURL := settings.SiteURL(setting.UseFirstSiteUrl(c)).String() + body := map[string]string{ + "callback": primaryURL, } + bodyByte, _ := json.Marshal(body) - // 计算总数用于分页 - tx.Count(&total) + r := dep.RequestClient() + res, err := r.Request( + "POST", + routes.SlavePingRoute(slave), + bytes.NewReader(bodyByte), + request.WithTimeout(time.Duration(10)*time.Second), + request.WithCredential( + auth.HMACAuth{SecretKey: []byte(service.Node.SlaveKey)}, + int64(settings.SlaveRequestSignTTL(c)), + ), + request.WithSlaveMeta(int(service.Node.ID)), + request.WithMasterMeta(settings.SiteBasic(c).ID, primaryURL), + request.WithCorrelationID(), + ).CheckHTTPResponse(http.StatusOK).DecodeResponse() - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "Failed to connect to node: "+err.Error(), nil) + } - isActive := make(map[uint]bool) - for i := 0; i < len(res); i++ { - if node := 
cluster.Default.GetNodeByID(res[i].ID); node != nil { - isActive[res[i].ID] = node.IsActive() - } + if res.Code != 0 { + return serializer.NewError(serializer.CodeParamErr, "Successfully connected to slave node, but slave returns: "+res.Msg, nil) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "active": isActive, - }} + return nil } -// ToggleNodeService 开关节点服务 -type ToggleNodeService struct { - ID uint `uri:"id"` - Desired model.NodeStatus `uri:"desired"` -} +type ( + TestNodeDownloaderService struct { + Node *ent.Node `json:"node" binding:"required"` + } + TestNodeDownloaderParamCtx struct{} +) + +func (service *TestNodeDownloaderService) Test(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + var ( + dl downloader.Downloader + err error + ) + if service.Node.Type == node.TypeMaster { + dl, err = cluster.NewDownloader(c, dep.RequestClient(request.WithContext(c)), dep.SettingProvider(), service.Node.Settings) + } else { + dl = slave.NewSlaveDownloader(dep.RequestClient( + request.WithContext(c), + request.WithCorrelationID(), + request.WithSlaveMeta(service.Node.ID), + request.WithMasterMeta(settings.SiteBasic(c).ID, settings.SiteURL(setting.UseFirstSiteUrl(c)).String()), + request.WithCredential(auth.HMACAuth{[]byte(service.Node.SlaveKey)}, int64(settings.SlaveRequestSignTTL(c))), + request.WithEndpoint(service.Node.Server), + ), service.Node.Settings) + } -// Toggle 开关节点 -func (service *ToggleNodeService) Toggle() serializer.Response { - node, err := model.GetNodeByID(service.ID) if err != nil { - return serializer.DBErr("Node not found", err) + return "", serializer.NewError(serializer.CodeParamErr, "Failed to create downloader", err) } - // 是否为系统节点 - if node.ID <= 1 { - return serializer.Err(serializer.CodeInvalidActionOnSystemNode, "", err) + version, err := dl.Test(c) + if err != nil { + return "", serializer.NewError(serializer.CodeParamErr, "Failed to test downloader: "+err.Error(), nil) } - if err = node.SetStatus(service.Desired); err != nil { - return serializer.DBErr("Failed to change node status", err) + return version, nil +} + +type ( + UpsertNodeService struct { + Node *ent.Node `json:"node" binding:"required"` } + UpsertNodeParamCtx struct{} +) - if service.Desired == model.NodeActive { - cluster.Default.Add(&node) - } else { - cluster.Default.Delete(node.ID) +func (s *UpsertNodeService) Update(c *gin.Context) (*GetNodeResponse, error) { + dep := dependency.FromContext(c) + nodeClient := dep.NodeClient() + + if s.Node.ID == 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "ID is required", nil) } - return serializer.Response{} -} + node, err := nodeClient.Upsert(c, s.Node) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update node", err) + } + + // reload node pool + np, err := dep.NodePool(c) + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err) + } + np.Upsert(c, node) + + // Clear policy cache since some this node maybe cached by some storage policy + kv := dep.KV() + kv.Delete(inventory.StoragePolicyCacheKey) -// NodeService 节点ID服务 -type NodeService struct { - ID uint `uri:"id" json:"id" binding:"required"` + service := &SingleNodeService{ID: node.ID} + return service.Get(c) } -// Delete 删除节点 -func (service *NodeService) Delete() serializer.Response { - // 查找用户组 - node, err := model.GetNodeByID(service.ID) - if err != nil { - return 
serializer.DBErr("Node record not found", err)
+func (s *UpsertNodeService) Create(c *gin.Context) (*GetNodeResponse, error) {
+	dep := dependency.FromContext(c)
+	nodeClient := dep.NodeClient()
+
+	if s.Node.ID > 0 {
+		return nil, serializer.NewError(serializer.CodeParamErr, "ID must be 0", nil)
 	}
-	// 是否为系统节点
-	if node.ID <= 1 {
-		return serializer.Err(serializer.CodeInvalidActionOnSystemNode, "", err)
+	node, err := nodeClient.Upsert(c, s.Node)
+	if err != nil {
+		return nil, serializer.NewError(serializer.CodeDBError, "Failed to create node", err)
 	}
-	cluster.Default.Delete(node.ID)
-	if err := model.DB.Delete(&node).Error; err != nil {
-		return serializer.DBErr("Failed to delete node record", err)
+	// reload node pool
+	np, err := dep.NodePool(c)
+	if err != nil {
+		return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err)
 	}
+	np.Upsert(c, node)
-	return serializer.Response{}
+	service := &SingleNodeService{ID: node.ID}
+	return service.Get(c)
 }
-// Get 获取节点详情
-func (service *NodeService) Get() serializer.Response {
-	node, err := model.GetNodeByID(service.ID)
+func (s *SingleNodeService) Delete(c *gin.Context) error {
+	dep := dependency.FromContext(c)
+	nodeClient := dep.NodeClient()
+
+	ctx := context.WithValue(c, inventory.LoadNodeStoragePolicy{}, true)
+	existing, err := nodeClient.GetNodeById(ctx, s.ID)
 	if err != nil {
-		return serializer.DBErr("Node not exist", err)
+		return serializer.NewError(serializer.CodeDBError, "Failed to get node", err)
+	}
+
+	if existing.Type == node.TypeMaster {
+		return serializer.NewError(serializer.CodeInvalidActionOnSystemNode, "", nil)
+	}
+
+	if len(existing.Edges.StoragePolicy) > 0 {
+		return serializer.NewError(
+			serializer.CodeNodeUsedByStoragePolicy,
+			strings.Join(lo.Map(existing.Edges.StoragePolicy, func(i *ent.StoragePolicy, _ int) string {
+				return i.Name
+			}), ", "),
+			nil,
+		)
 	}
-	return serializer.Response{Data: node}
+	// insert dummy disabled node in node pool to evict it
+	disabledNode := &ent.Node{
+		ID:     s.ID,
+		Type:   node.TypeSlave,
+		Status: node.StatusSuspended,
+	}
+	np, err := dep.NodePool(c)
+	if err != nil {
+		return serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err)
+	}
+	np.Upsert(c, disabledNode)
+	return nodeClient.Delete(c, s.ID)
 }
diff --git a/service/admin/policy.go b/service/admin/policy.go
index 478203ad..3f24bcf2 100644
--- a/service/admin/policy.go
+++ b/service/admin/policy.go
@@ -1,32 +1,33 @@
 package admin
 import (
-	"bytes"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
-	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
-	"net/http"
 	"net/url"
-	"os"
-	"path/filepath"
 	"strconv"
 	"strings"
 	"time"
-	model "github.com/cloudreve/Cloudreve/v3/models"
-	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
-	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
-	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
-	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos"
-	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
-	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
-	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
-	"github.com/cloudreve/Cloudreve/v3/pkg/request"
-	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
-	"github.com/cloudreve/Cloudreve/v3/pkg/util"
+	"github.com/cloudreve/Cloudreve/v4/application/constants"
+	"github.com/cloudreve/Cloudreve/v4/application/dependency"
+	"github.com/cloudreve/Cloudreve/v4/ent"
+	"github.com/cloudreve/Cloudreve/v4/inventory"
+	
"github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/cos" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/obs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/onedrive" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/s3" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" - cossdk "github.com/tencentyun/cos-go-sdk-v5" ) // PathTestService 本地路径测试服务 @@ -40,14 +41,17 @@ type SlaveTestService struct { Server string `json:"server" binding:"required"` } -// SlavePingService 从机相应ping -type SlavePingService struct { - Callback string `json:"callback" binding:"required"` -} +type ( + SlavePingParameterCtx struct{} + // SlavePingService ping slave node + SlavePingService struct { + Callback string `json:"callback" binding:"required"` + } +) // AddPolicyService 存储策略添加服务 type AddPolicyService struct { - Policy model.Policy `json:"policy" binding:"required"` + //Policy model.Policy `json:"policy" binding:"required"` } // PolicyService 存储策略ID服务 @@ -57,301 +61,444 @@ type PolicyService struct { } // Delete 删除存储策略 -func (service *PolicyService) Delete() serializer.Response { +func (service *SingleStoragePolicyService) Delete(c *gin.Context) error { // 禁止删除默认策略 if service.ID == 1 { - return serializer.Err(serializer.CodeDeleteDefaultPolicy, "", nil) + return serializer.NewError(serializer.CodeDeleteDefaultPolicy, "", nil) + } + + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() + + ctx := context.WithValue(c, inventory.LoadStoragePolicyGroup{}, true) + ctx = context.WithValue(ctx, inventory.SkipStoragePolicyCache{}, true) + policy, err := storagePolicyClient.GetPolicyByID(ctx, service.ID) + if err != nil { + return serializer.NewError(serializer.CodePolicyNotExist, "", err) + } + + // If policy is used by groups, return error + if len(policy.Edges.Groups) > 0 { + return serializer.NewError(serializer.CodePolicyUsedByGroups, strconv.Itoa(len(policy.Edges.Groups)), nil) } - policy, err := model.GetPolicyByID(service.ID) + used, err := dep.FileClient().IsStoragePolicyUsedByEntities(ctx, service.ID) if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) + return serializer.NewError(serializer.CodeDBError, "Failed to check if policy is used by entities", err) } - // 检查是否有文件使用 - total := 0 - row := model.DB.Model(&model.File{}).Where("policy_id = ?", service.ID). 
- Select("count(id)").Row() - row.Scan(&total) - if total > 0 { - return serializer.Err(serializer.CodePolicyUsedByFiles, strconv.Itoa(total), nil) + if used { + return serializer.NewError(serializer.CodePolicyUsedByFiles, "", nil) } - // 检查用户组使用 - var groups []model.Group - model.DB.Model(&model.Group{}).Where( - "policies like ?", - fmt.Sprintf("%%[%d]%%", service.ID), - ).Find(&groups) + err = storagePolicyClient.Delete(ctx, policy) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete policy", err) + } - if len(groups) > 0 { - return serializer.Err(serializer.CodePolicyUsedByGroups, strconv.Itoa(len(groups)), nil) + return nil +} + +// Test 从机响应ping +func (service *SlavePingService) Test(c *gin.Context) error { + master, err := url.Parse(service.Callback) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "Failed to parse callback url", err) } - model.DB.Delete(&policy) - policy.ClearCache() + dep := dependency.FromContext(c) + r := dep.RequestClient() + res, err := r.Request( + "GET", + routes.MasterPingUrl(master).String(), + nil, + request.WithContext(c), + request.WithLogger(logging.FromContext(c)), + request.WithCorrelationID(), + request.WithTimeout(time.Duration(10)*time.Second), + ).DecodeResponse() + + if err != nil { + return serializer.NewError(serializer.CodeSlavePingMaster, err.Error(), nil) + } + + version := constants.BackendVersion + + if strings.TrimSuffix(res.Data.(string), "-pro") != version { + return serializer.NewError(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil) + } + + return nil +} + +// Test 测试从机通信 +func (service *SlaveTestService) Test() serializer.Response { + //slave, err := url.Parse(service.Server) + //if err != nil { + // return serializer.ParamErrDeprecated("Failed to parse slave node server URL: "+err.Error(), nil) + //} + // + //controller, _ := url.Parse("/api/v3/slave/ping") + // + //// 请求正文 + //body := map[string]string{ + // "callback": model.GetSiteURL().String(), + //} + //bodyByte, _ := json.Marshal(body) + // + //r := request.NewClientDeprecated() + //res, err := r.Request( + // "POST", + // slave.ResolveReference(controller).String(), + // bytes.NewReader(bodyByte), + // request.WithTimeout(time.Duration(10)*time.Second), + // request.WithCredential( + // auth.HMACAuth{SecretKey: []byte(service.Secret)}, + // int64(model.GetIntSetting("slave_api_timeout", 60)), + // ), + //).DecodeResponse() + //if err != nil { + // return serializer.ParamErrDeprecated("Failed to connect to slave node: "+err.Error(), nil) + //} + // + //if res.Code != 0 { + // return serializer.ParamErrDeprecated("Successfully connected to slave node, but slave returns: "+res.Msg, nil) + //} + + return serializer.Response{} +} + +// Test 测试本地路径 +func (service *PathTestService) Test() serializer.Response { + //policy := model.Policy{DirNameRule: service.Path} + //path := policy.GeneratePath(1, "/My File") + //path = filepath.Join(path, "test.txt") + //file, err := util.CreatNestedFile(util.RelativePath(path)) + //if err != nil { + // return serializer.ParamErrDeprecated(fmt.Sprintf("Failed to create \"%s\": %s", path, err.Error()), nil) + //} + // + //file.Close() + //os.Remove(path) return serializer.Response{} } -// Get 获取存储策略详情 -func (service *PolicyService) Get() serializer.Response { - policy, err := model.GetPolicyByID(service.ID) +const ( + policyTypeCondition = "policy_type" +) + +// Policies 列出存储策略 +func (service *AdminListService) Policies(c *gin.Context) 
(*ListPolicyResponse, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() + + ctx := context.WithValue(c, inventory.LoadStoragePolicyGroup{}, true) + res, err := storagePolicyClient.ListPolicies(ctx, &inventory.ListPolicyParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: service.Page - 1, + PageSize: service.PageSize, + OrderBy: service.OrderBy, + Order: inventory.OrderDirection(service.OrderDirection), + }, + Type: types.PolicyType(service.Conditions[policyTypeCondition]), + }) + if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list policies", err) } - return serializer.Response{Data: policy} + return &ListPolicyResponse{ + Pagination: res.PaginationResults, + Policies: res.Policies, + }, nil } -// GetOAuth 获取 OneDrive OAuth 地址 -func (service *PolicyService) GetOAuth(c *gin.Context, policyType string) serializer.Response { - policy, err := model.GetPolicyByID(service.ID) - if err != nil || policy.Type != policyType { - return serializer.Err(serializer.CodePolicyNotExist, "", nil) +type ( + SingleStoragePolicyService struct { + ID int `uri:"id" json:"id" binding:"required"` } + GetStoragePolicyParamCtx struct{} +) - util.SetSession(c, map[string]interface{}{ - policyType + "_oauth_policy": policy.ID, - }) +const ( + countEntityQuery = "countEntity" +) - var redirect string - switch policy.Type { - case "onedrive": - client, err := onedrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) - } +func (service *SingleStoragePolicyService) Get(c *gin.Context) (*GetStoragePolicyResponse, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() - redirect = client.OAuthURL(context.Background(), []string{ - "offline_access", - "files.readwrite.all", - }) - case "googledrive": - client, err := googledrive.NewClient(&policy) + ctx := context.WithValue(c, inventory.LoadStoragePolicyGroup{}, true) + ctx = context.WithValue(ctx, inventory.SkipStoragePolicyCache{}, true) + policy, err := storagePolicyClient.GetPolicyByID(ctx, service.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get policy", err) + } + + res := &GetStoragePolicyResponse{StoragePolicy: policy} + if c.Query(countEntityQuery) != "" { + count, size, err := dep.FileClient().CountEntityByStoragePolicyID(ctx, service.ID) if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to count entities", err) } + res.EntitiesCount = count + res.EntitiesSize = size + } - redirect = client.OAuthURL(context.Background(), googledrive.RequiredScope) + return res, nil +} + +type ( + CreateStoragePolicyService struct { + Policy *ent.StoragePolicy `json:"policy" binding:"required"` } + CreateStoragePolicyParamCtx struct{} +) - // Delete token cache - cache.Deletes([]string{policy.BucketName}, policyType+"_") +func (service *CreateStoragePolicyService) Create(c *gin.Context) (*GetStoragePolicyResponse, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() + + if service.Policy.Type == types.PolicyTypeLocal { + service.Policy.DirNameRule = util.DataPath("uploads/{uid}/{path}") + } - return serializer.Response{Data: redirect} + service.Policy.ID = 0 + policy, err := 
storagePolicyClient.Upsert(c, service.Policy) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create policy", err) + } + + return &GetStoragePolicyResponse{StoragePolicy: policy}, nil } -// AddSCF 创建回调云函数 -func (service *PolicyService) AddSCF() serializer.Response { - policy, err := model.GetPolicyByID(service.ID) +type ( + UpdateStoragePolicyService struct { + Policy *ent.StoragePolicy `json:"policy" binding:"required"` + } + UpdateStoragePolicyParamCtx struct{} +) + +func (service *UpdateStoragePolicyService) Update(c *gin.Context) (*GetStoragePolicyResponse, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() + + id := c.Param("id") + if id == "" { + return nil, serializer.NewError(serializer.CodeParamErr, "ID is required", nil) + } + idInt, err := strconv.Atoi(id) if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", nil) + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid ID", err) } - if err := cos.CreateSCF(&policy, service.Region); err != nil { - return serializer.ParamErr("Failed to create SCF function", err) + service.Policy.ID = idInt + _, err = storagePolicyClient.Upsert(c, service.Policy) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update policy", err) } - return serializer.Response{} + _ = dep.KV().Delete(manager.EntityUrlCacheKeyPrefix) + + s := SingleStoragePolicyService{ID: idInt} + return s.Get(c) } -// AddCORS 创建跨域策略 -func (service *PolicyService) AddCORS() serializer.Response { - policy, err := model.GetPolicyByID(service.ID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", nil) +type ( + CreateStoragePolicyCorsService struct { + Policy *ent.StoragePolicy `json:"policy" binding:"required"` } + CreateStoragePolicyCorsParamCtx struct{} +) + +func (service *CreateStoragePolicyCorsService) Create(c *gin.Context) error { + dep := dependency.FromContext(c) - switch policy.Type { - case "oss": - handler, err := oss.NewDriver(&policy) + switch service.Policy.Type { + case types.PolicyTypeOss: + handler, err := oss.New(c, service.Policy, dep.SettingProvider(), dep.ConfigProvider(), dep.Logger(), dep.MimeDetector(c)) if err != nil { - return serializer.Err(serializer.CodeAddCORS, "", err) + return serializer.NewError(serializer.CodeDBError, "Failed to create oss driver", err) } if err := handler.CORS(); err != nil { - return serializer.Err(serializer.CodeAddCORS, "", err) + return serializer.NewError(serializer.CodeInternalSetting, "Failed to create cors: "+err.Error(), err) + } + + return nil + + case types.PolicyTypeCos: + handler, err := cos.New(c, service.Policy, dep.SettingProvider(), dep.ConfigProvider(), dep.Logger(), dep.MimeDetector(c)) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to create cos driver", err) + } + + if err := handler.CORS(); err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to create cors: "+err.Error(), err) } - case "cos": - u, _ := url.Parse(policy.Server) - b := &cossdk.BaseURL{BucketURL: u} - handler := cos.Driver{ - Policy: &policy, - HTTPClient: request.NewClient(), - Client: cossdk.NewClient(b, &http.Client{ - Transport: &cossdk.AuthorizationTransport{ - SecretID: policy.AccessKey, - SecretKey: policy.SecretKey, - }, - }), + + return nil + + case types.PolicyTypeS3: + handler, err := s3.New(c, service.Policy, dep.SettingProvider(), dep.ConfigProvider(), dep.Logger(), dep.MimeDetector(c)) + if 
err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to create s3 driver", err) } if err := handler.CORS(); err != nil { - return serializer.Err(serializer.CodeAddCORS, "", err) + return serializer.NewError(serializer.CodeInternalSetting, "Failed to create cors: "+err.Error(), err) } - case "s3": - handler, err := s3.NewDriver(&policy) + + return nil + + case types.PolicyTypeObs: + handler, err := obs.New(c, service.Policy, dep.SettingProvider(), dep.ConfigProvider(), dep.Logger(), dep.MimeDetector(c)) if err != nil { - return serializer.Err(serializer.CodeAddCORS, "", err) + return serializer.NewError(serializer.CodeDBError, "Failed to create obs driver", err) } if err := handler.CORS(); err != nil { - return serializer.Err(serializer.CodeAddCORS, "", err) + return serializer.NewError(serializer.CodeInternalSetting, "Failed to create cors: "+err.Error(), err) } + + return nil default: - return serializer.Err(serializer.CodePolicyNotAllowed, "", nil) + return serializer.NewError(serializer.CodeParamErr, "Unsupported policy type", nil) } - - return serializer.Response{} } -// Test 从机响应ping -func (service *SlavePingService) Test() serializer.Response { - master, err := url.Parse(service.Callback) - if err != nil { - return serializer.ParamErr("Failed to parse Master site url: "+err.Error(), nil) +type ( + GetOauthRedirectService struct { + ID int `json:"id" binding:"required"` + Secret string `json:"secret" binding:"required"` + AppID string `json:"app_id" binding:"required"` } + GetOauthRedirectParamCtx struct{} +) - controller, _ := url.Parse("/api/v3/site/ping") +// GetOAuth 获取 OneDrive OAuth 地址 +func (service *GetOauthRedirectService) GetOAuth(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() - r := request.NewClient() - res, err := r.Request( - "GET", - master.ResolveReference(controller).String(), - nil, - request.WithTimeout(time.Duration(10)*time.Second), - ).DecodeResponse() + policy, err := storagePolicyClient.GetPolicyByID(c, service.ID) + if err != nil || policy.Type != types.PolicyTypeOd { + return "", serializer.NewError(serializer.CodePolicyNotExist, "", nil) + } + // Update to latest redirect url + policy.Settings.OauthRedirect = routes.MasterPolicyOAuthCallback(dep.SettingProvider().SiteURL(c)).String() + policy.SecretKey = service.Secret + policy.BucketName = service.AppID + policy, err = storagePolicyClient.Upsert(c, policy) if err != nil { - return serializer.Err(serializer.CodeSlavePingMaster, err.Error(), nil) + return "", serializer.NewError(serializer.CodeDBError, "Failed to update policy", err) } - version := conf.BackendVersion - if conf.IsPro == "true" { - version += "-pro" - } - if res.Data.(string) != version { - return serializer.Err(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil) - } + client := onedrive.NewClient(policy, dep.RequestClient(), dep.CredManager(), dep.Logger(), dep.SettingProvider(), 0) + redirect := client.OAuthURL(context.Background(), []string{ + "offline_access", + "files.readwrite.all", + }) - return serializer.Response{} + return redirect, nil } -// Test 测试从机通信 -func (service *SlaveTestService) Test() serializer.Response { - slave, err := url.Parse(service.Server) - if err != nil { - return serializer.ParamErr("Failed to parse slave node server URL: "+err.Error(), nil) - } +func GetPolicyOAuthURL(c *gin.Context) string { + dep := dependency.FromContext(c) + return 
routes.MasterPolicyOAuthCallback(dep.SettingProvider().SiteURL(c)).String() +} - controller, _ := url.Parse("/api/v3/slave/ping") +// GetOauthCredentialStatus returns last refresh time of oauth credential +func (service *SingleStoragePolicyService) GetOauthCredentialStatus(c *gin.Context) (*OauthCredentialStatus, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() - // 请求正文 - body := map[string]string{ - "callback": model.GetSiteURL().String(), + policy, err := storagePolicyClient.GetPolicyByID(c, service.ID) + if err != nil || policy.Type != types.PolicyTypeOd { + return nil, serializer.NewError(serializer.CodePolicyNotExist, "", nil) } - bodyByte, _ := json.Marshal(body) - r := request.NewClient() - res, err := r.Request( - "POST", - slave.ResolveReference(controller).String(), - bytes.NewReader(bodyByte), - request.WithTimeout(time.Duration(10)*time.Second), - request.WithCredential( - auth.HMACAuth{SecretKey: []byte(service.Secret)}, - int64(model.GetIntSetting("slave_api_timeout", 60)), - ), - ).DecodeResponse() - if err != nil { - return serializer.ParamErr("Failed to connect to slave node: "+err.Error(), nil) + if policy.AccessKey == "" { + return &OauthCredentialStatus{Valid: false}, nil } - if res.Code != 0 { - return serializer.ParamErr("Successfully connected to slave node, but slave returns: "+res.Msg, nil) + token, err := dep.CredManager().Obtain(c, onedrive.CredentialKey(policy.ID)) + if err != nil { + if errors.Is(err, credmanager.ErrNotFound) { + return &OauthCredentialStatus{Valid: false}, nil + } + + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get credential", err) } - return serializer.Response{} + return &OauthCredentialStatus{Valid: true, LastRefreshTime: token.RefreshedAt()}, nil } -// Add 添加存储策略 -func (service *AddPolicyService) Add() serializer.Response { - if service.Policy.Type != "local" && service.Policy.Type != "remote" { - service.Policy.DirNameRule = strings.TrimPrefix(service.Policy.DirNameRule, "/") - } - - if service.Policy.ID > 0 { - if err := model.DB.Save(&service.Policy).Error; err != nil { - return serializer.DBErr("Failed to save policy", err) - } - } else { - if err := model.DB.Create(&service.Policy).Error; err != nil { - return serializer.DBErr("Failed to create policy", err) - } +type ( + FinishOauthCallbackService struct { + Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` } + FinishOauthCallbackParamCtx struct{} +) - service.Policy.ClearCache() - - return serializer.Response{Data: service.Policy.ID} -} +func (service *FinishOauthCallbackService) Finish(c *gin.Context) error { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() -// Test 测试本地路径 -func (service *PathTestService) Test() serializer.Response { - policy := model.Policy{DirNameRule: service.Path} - path := policy.GeneratePath(1, "/My File") - path = filepath.Join(path, "test.txt") - file, err := util.CreatNestedFile(util.RelativePath(path)) + policyId, err := strconv.Atoi(service.State) if err != nil { - return serializer.ParamErr(fmt.Sprintf("Failed to create \"%s\": %s", path, err.Error()), nil) + return serializer.NewError(serializer.CodeParamErr, "Invalid state", err) } - file.Close() - os.Remove(path) + policy, err := storagePolicyClient.GetPolicyByID(c, policyId) + if err != nil { + return serializer.NewError(serializer.CodePolicyNotExist, "", nil) + } - return serializer.Response{} -} + if policy.Type != types.PolicyTypeOd { + return 
serializer.NewError(serializer.CodeParamErr, "Invalid policy type", nil) + } -// Policies 列出存储策略 -func (service *AdminListService) Policies() serializer.Response { - var res []model.Policy - total := 0 + client := onedrive.NewClient(policy, dep.RequestClient(), dep.CredManager(), dep.Logger(), dep.SettingProvider(), 0) + credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code)) + if err != nil { + return serializer.NewError(serializer.CodeIncorrectPassword, "Failed to obtain token", err) + } - tx := model.DB.Model(&model.Policy{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) + credManager := dep.CredManager() + err = credManager.Upsert(c, credential) + if err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to upsert credential", err) } - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) + _, err = credManager.Obtain(c, onedrive.CredentialKey(policy.ID)) + if err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to obtain credential", err) } - // 计算总数用于分页 - tx.Count(&total) + return nil +} - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) +func (service *SingleStoragePolicyService) GetSharePointDriverRoot(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + storagePolicyClient := dep.StoragePolicyClient() - // 统计每个策略的文件使用 - statics := make(map[uint][2]int, len(res)) - policyIds := make([]uint, 0, len(res)) - for i := 0; i < len(res); i++ { - policyIds = append(policyIds, res[i].ID) + policy, err := storagePolicyClient.GetPolicyByID(c, service.ID) + if err != nil { + return "", serializer.NewError(serializer.CodePolicyNotExist, "", nil) } - rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds). 
- Select("policy_id,count(id),sum(size)").Group("policy_id").Rows() - - for rows.Next() { - policyId := uint(0) - total := [2]int{} - rows.Scan(&policyId, &total[0], &total[1]) + if policy.Type != types.PolicyTypeOd { + return "", serializer.NewError(serializer.CodeParamErr, "Invalid policy type", nil) + } - statics[policyId] = total + client := onedrive.NewClient(policy, dep.RequestClient(), dep.CredManager(), dep.Logger(), dep.SettingProvider(), 0) + root, err := client.GetSiteIDByURL(c, c.Query("url")) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to get site id", err) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "statics": statics, - }} + return fmt.Sprintf("sites/%s/drive", root), nil } diff --git a/service/admin/response.go b/service/admin/response.go new file mode 100644 index 00000000..1b2a8f6e --- /dev/null +++ b/service/admin/response.go @@ -0,0 +1,141 @@ +package admin + +import ( + "encoding/gob" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" +) + +type ListShareResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Shares []GetShareResponse `json:"shares"` +} + +type GetShareResponse struct { + *ent.Share + UserHashID string `json:"user_hash_id,omitempty"` + ShareLink string `json:"share_link,omitempty"` +} + +type ListTaskResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Tasks []GetTaskResponse `json:"tasks"` +} + +type GetTaskResponse struct { + *ent.Task + UserHashID string `json:"user_hash_id,omitempty"` + Summary *queue.Summary `json:"summary,omitempty"` + Node *ent.Node `json:"node,omitempty"` +} + +type ListEntityResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Entities []GetEntityResponse `json:"entities"` +} + +type GetEntityResponse struct { + *ent.Entity + UserHashID string `json:"user_hash_id,omitempty"` + UserHashIDMap map[int]string `json:"user_hash_id_map,omitempty"` +} + +type ListFileResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Files []GetFileResponse `json:"files"` +} + +type GetFileResponse struct { + *ent.File + UserHashID string `json:"user_hash_id,omitempty"` + DirectLinkMap map[int]string `json:"direct_link_map,omitempty"` +} + +type ListUserResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Users []GetUserResponse `json:"users"` +} + +type GetUserResponse struct { + *ent.User + HashID string `json:"hash_id,omitempty"` + TwoFAEnabled bool `json:"two_fa_enabled,omitempty"` + Capacity *fs.Capacity `json:"capacity,omitempty"` +} + +type GetNodeResponse struct { + *ent.Node +} + +type GetGroupResponse struct { + *ent.Group + TotalUsers int `json:"total_users"` +} + +type OauthCredentialStatus struct { + Valid bool `json:"valid"` + LastRefreshTime *time.Time `json:"last_refresh_time"` +} + +type GetStoragePolicyResponse struct { + *ent.StoragePolicy + EntitiesCount int `json:"entities_count,omitempty"` + EntitiesSize int `json:"entities_size,omitempty"` +} + +type ListNodeResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Nodes []*ent.Node `json:"nodes"` +} + +type ListPolicyResponse struct { + Pagination *inventory.PaginationResults `json:"pagination"` + Policies 
[]*ent.StoragePolicy `json:"policies"` +} + +type QueueMetric struct { + Name setting.QueueType `json:"name"` + BusyWorkers int `json:"busy_workers"` + SuccessTasks int `json:"success_tasks"` + FailureTasks int `json:"failure_tasks"` + SubmittedTasks int `json:"submitted_tasks"` + SuspendingTasks int `json:"suspending_tasks"` +} + +type ListGroupResponse struct { + Groups []*ent.Group `json:"groups"` + Pagination *inventory.PaginationResults `json:"pagination"` +} + +type HomepageSummary struct { + MetricsSummary *MetricsSummary `json:"metrics_summary"` + SiteURls []string `json:"site_urls"` + Version *Version `json:"version"` +} + +type MetricsSummary struct { + Dates []time.Time `json:"dates"` + Files []int `json:"files"` + Users []int `json:"users"` + Shares []int `json:"shares"` + FileTotal int `json:"file_total"` + UserTotal int `json:"user_total"` + ShareTotal int `json:"share_total"` + EntitiesTotal int `json:"entities_total"` + GeneratedAt time.Time `json:"generated_at"` +} + +type Version struct { + Version string `json:"version"` + Pro bool `json:"pro"` + Commit string `json:"commit"` +} + +func init() { + gob.Register(MetricsSummary{}) +} diff --git a/service/admin/share.go b/service/admin/share.go index 66d89fae..af88e529 100644 --- a/service/admin/share.go +++ b/service/admin/share.go @@ -1,80 +1,145 @@ package admin import ( - "strings" + "context" + "strconv" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) -// ShareBatchService 分享批量操作服务 -type ShareBatchService struct { - ID []uint `json:"id" binding:"min=1"` -} +const ( + shareUserIDCondition = "share_user_id" + shareFileIDCondition = "share_file_id" +) -// Delete 删除文件 -func (service *ShareBatchService) Delete(c *gin.Context) serializer.Response { - if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Share{}).Error; err != nil { - return serializer.DBErr("Failed to delete share record", err) - } - return serializer.Response{} -} +func (s *AdminListService) Shares(c *gin.Context) (*ListShareResponse, error) { + dep := dependency.FromContext(c) + shareClient := dep.ShareClient() + hasher := dep.HashIDEncoder() -// Shares 列出分享 -func (service *AdminListService) Shares() serializer.Response { - var res []model.Share - total := 0 + var ( + err error + userID int + fileID int + ) - tx := model.DB.Model(&model.Share{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) + if s.Conditions[shareUserIDCondition] != "" { + userID, err = strconv.Atoi(s.Conditions[shareUserIDCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid share user ID", err) + } } - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) + if s.Conditions[shareFileIDCondition] != "" { + fileID, err = strconv.Atoi(s.Conditions[shareFileIDCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid share file ID", err) + } } - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += k + " like '%" + v + "%' OR " - } - search = strings.TrimSuffix(search, " 
OR ") - tx = tx.Where(search) + ctx := context.WithValue(c, inventory.LoadShareFile{}, true) + ctx = context.WithValue(ctx, inventory.LoadShareUser{}, true) + + res, err := shareClient.List(ctx, &inventory.ListShareArgs{ + PaginationArgs: &inventory.PaginationArgs{ + Page: s.Page - 1, + PageSize: s.PageSize, + OrderBy: s.OrderBy, + Order: inventory.OrderDirection(s.OrderDirection), + }, + UserID: userID, + FileID: fileID, + }) + + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list shares", err) } - // 计算总数用于分页 - tx.Count(&total) + siteUrl := dep.SettingProvider().SiteURL(c) + + return &ListShareResponse{ + Pagination: res.PaginationResults, + Shares: lo.Map(res.Shares, func(share *ent.Share, _ int) GetShareResponse { + var ( + uid string + shareLink string + ) - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + if share.Edges.User != nil { + uid = hashid.EncodeUserID(hasher, share.Edges.User.ID) + } - // 查询对应用户,同时计算HashID - users := make(map[uint]model.User) - hashIDs := make(map[uint]string, len(res)) - for _, file := range res { - users[file.UserID] = model.User{} - hashIDs[file.ID] = hashid.HashID(file.ID, hashid.ShareID) + shareLink = routes.MasterShareUrl(siteUrl, hashid.EncodeShareID(hasher, share.ID), share.Password).String() + + return GetShareResponse{ + Share: share, + UserHashID: uid, + ShareLink: shareLink, + } + }), + }, nil + +} + +type ( + SingleShareService struct { + ShareID int `uri:"id" binding:"required"` + } + SingleShareParamCtx struct{} +) + +func (s *SingleShareService) Get(c *gin.Context) (*GetShareResponse, error) { + dep := dependency.FromContext(c) + shareClient := dep.ShareClient() + hasher := dep.HashIDEncoder() + + ctx := context.WithValue(c, inventory.LoadShareFile{}, true) + ctx = context.WithValue(ctx, inventory.LoadShareUser{}, true) + share, err := shareClient.GetByID(ctx, s.ShareID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get share", err) + } + + var ( + uid string + shareLink string + ) + + if share.Edges.User != nil { + uid = hashid.EncodeShareID(hasher, share.Edges.User.ID) } - userIDs := make([]uint, 0, len(users)) - for k := range users { - userIDs = append(userIDs, k) + siteUrl := dep.SettingProvider().SiteURL(c) + shareLink = routes.MasterShareUrl(siteUrl, hashid.EncodeShareID(hasher, share.ID), share.Password).String() + + return &GetShareResponse{ + Share: share, + UserHashID: uid, + ShareLink: shareLink, + }, nil +} + +type ( + BatchShareService struct { + ShareIDs []int `json:"ids" binding:"required"` } + BatchShareParamCtx struct{} +) - var userList []model.User - model.DB.Where("id in (?)", userIDs).Find(&userList) +func (s *BatchShareService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + shareClient := dep.ShareClient() - for _, v := range userList { - users[v.ID] = v + if err := shareClient.DeleteBatch(c, s.ShareIDs); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete shares", err) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "users": users, - "ids": hashIDs, - }} + return nil } diff --git a/service/admin/site.go b/service/admin/site.go index 69aa2d8f..7b5a6332 100644 --- a/service/admin/site.go +++ b/service/admin/site.go @@ -1,16 +1,25 @@ package admin import ( + "context" "encoding/gob" + "encoding/json" + "fmt" + "net/url" + "reflect" + "strings" "time" - model 
"github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/thumb" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) func init() { @@ -33,133 +42,371 @@ type SettingChangeService struct { Value string `json:"value"` } -// BatchSettingGet 设定批量获取服务 -type BatchSettingGet struct { - Keys []string `json:"keys"` +// Change 批量更改站点设定 +func (service *BatchSettingChangeService) Change() serializer.Response { + //cacheClean := make([]string, 0, len(service.Options)) + //tx := model.DB.Begin() + // + //for _, setting := range service.Options { + // + // if err := tx.Model(&model.Setting{}).Where("name = ?", setting.Key).Update("value", setting.Value).Error; err != nil { + // cache.Deletes(cacheClean, "setting_") + // tx.Rollback() + // return serializer.ErrDeprecated(serializer.CodeUpdateSetting, "Setting "+setting.Key+" failed to update", err) + // } + // + // cacheClean = append(cacheClean, setting.Key) + //} + // + //if err := tx.Commit().Error; err != nil { + // return serializer.DBErrDeprecated("Failed to update setting", err) + //} + // + //cache.Deletes(cacheClean, "setting_") + + return serializer.Response{} } -// MailTestService 邮件测试服务 -type MailTestService struct { - Email string `json:"to" binding:"email"` +const ( + SummaryRangeDays = 12 + MetricCacheKey = "admin_summary" + metricErrMsg = "Failed to generate metrics summary" +) + +type ( + SummaryService struct { + Generate bool `form:"generate"` + } + SummaryParamCtx struct{} +) + +// Summary 获取站点统计概况 +func (s *SummaryService) Summary(c *gin.Context) (*HomepageSummary, error) { + dep := dependency.FromContext(c) + kv := dep.KV() + res := &HomepageSummary{ + Version: &Version{ + Version: constants.BackendVersion, + Pro: constants.IsProBool, + Commit: constants.LastCommit, + }, + SiteURls: lo.Map(dep.SettingProvider().AllSiteURLs(c), func(item *url.URL, index int) string { + return item.String() + }), + } + + if summary, ok := kv.Get(MetricCacheKey); ok { + summaryCasted := summary.(MetricsSummary) + res.MetricsSummary = &summaryCasted + return res, nil + } + + if !s.Generate { + return res, nil + } + + summary := &MetricsSummary{ + Files: make([]int, SummaryRangeDays), + Users: make([]int, SummaryRangeDays), + Shares: make([]int, SummaryRangeDays), + Dates: make([]time.Time, SummaryRangeDays), + GeneratedAt: time.Now(), + } + + fileClient := dep.FileClient() + userClient := dep.UserClient() + shareClient := dep.ShareClient() + + toRound := time.Now() + timeBase := time.Date(toRound.Year(), toRound.Month(), toRound.Day()+1, 0, 0, 0, 0, toRound.Location()) + for day := range summary.Files { + start := timeBase.Add(-time.Duration(SummaryRangeDays-day) * time.Hour * 24) + end := timeBase.Add(-time.Duration(SummaryRangeDays-day-1) * time.Hour * 24) + summary.Dates[day] = start + fileTotal, err := fileClient.CountByTimeRange(c, &start, &end) + if err != nil { + return nil, 
serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + userTotal, err := userClient.CountByTimeRange(c, &start, &end) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + shareTotal, err := shareClient.CountByTimeRange(c, &start, &end) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + summary.Files[day] = fileTotal + summary.Users[day] = userTotal + summary.Shares[day] = shareTotal + } + + var err error + summary.FileTotal, err = fileClient.CountByTimeRange(c, nil, nil) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + summary.UserTotal, err = userClient.CountByTimeRange(c, nil, nil) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + summary.ShareTotal, err = shareClient.CountByTimeRange(c, nil, nil) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + summary.EntitiesTotal, err = fileClient.CountEntityByTimeRange(c, nil, nil) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, metricErrMsg, nil) + } + + _ = kv.Set(MetricCacheKey, *summary, 86400) + res.MetricsSummary = summary + + return res, nil } -// Send 发送测试邮件 -func (service *MailTestService) Send() serializer.Response { - if err := email.Send(service.Email, "Cloudreve Email delivery test", "This is a test Email, to test Cloudreve Email delivery settings"); err != nil { - return serializer.Err(serializer.CodeFailedSendEmail, err.Error(), nil) +// ThumbGeneratorTestService 缩略图生成测试服务 +type ( + ThumbGeneratorTestService struct { + Name string `json:"name" binding:"required"` + Executable string `json:"executable" binding:"required"` } - return serializer.Response{} + ThumbGeneratorTestParamCtx struct{} +) + +// Test 通过获取生成器版本来测试 +func (s *ThumbGeneratorTestService) Test(c *gin.Context) (string, error) { + version, err := thumb.TestGenerator(c, s.Name, s.Executable) + if err != nil { + return "", serializer.NewError(serializer.CodeParamErr, "Failed to invoke generator: "+err.Error(), err) + } + + return version, nil } -// Get 获取设定值 -func (service *BatchSettingGet) Get() serializer.Response { - options := model.GetSettingByNames(service.Keys...) 
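A minimal, self-contained sketch of the per-day window arithmetic used by the new Summary service above, included for reference; it assumes nothing beyond what the hunk shows (SummaryRangeDays buckets ending at the upcoming local midnight), and every identifier below is local to the example rather than part of this patch.

package main

import (
	"fmt"
	"time"
)

func main() {
	const summaryRangeDays = 12 // mirrors SummaryRangeDays in the hunk above

	now := time.Now()
	// Midnight at the start of the next local day; the last bucket ends exactly here.
	timeBase := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())

	for day := 0; day < summaryRangeDays; day++ {
		start := timeBase.Add(-time.Duration(summaryRangeDays-day) * 24 * time.Hour)
		end := timeBase.Add(-time.Duration(summaryRangeDays-day-1) * 24 * time.Hour)
		// Each bucket spans one day, from start up to end, oldest bucket first.
		fmt.Printf("bucket %2d: %s -> %s\n", day, start.Format("2006-01-02 15:04"), end.Format("2006-01-02 15:04"))
	}
}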
- return serializer.Response{Data: options} +type ( + GetSettingService struct { + Keys []string `json:"keys" binding:"required"` + } + GetSettingParamCtx struct{} +) + +func (s *GetSettingService) GetSetting(c *gin.Context) (map[string]string, error) { + dep := dependency.FromContext(c) + res, err := dep.SettingClient().Gets(c, lo.Filter(s.Keys, func(item string, index int) bool { + return item != "secret_key" + })) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get settings", err) + } + + return res, nil } -// Change 批量更改站点设定 -func (service *BatchSettingChangeService) Change() serializer.Response { - cacheClean := make([]string, 0, len(service.Options)) - tx := model.DB.Begin() +type ( + SetSettingService struct { + Settings map[string]string `json:"settings" binding:"required"` + } + SetSettingParamCtx struct{} + SettingPreProcessor func(ctx context.Context, settings map[string]string) error + SettingPostProcessor func(ctx context.Context, settings map[string]string) error +) - for _, setting := range service.Options { +var ( + preprocessors = map[string]SettingPreProcessor{ + "siteURL": siteUrlPreProcessor, + "mime_mapping": mimeMappingPreProcessor, + "secret_key": secretKeyPreProcessor, + } + postprocessors = map[string]SettingPostProcessor{ + "mime_mapping": mimeMappingPostProcessor, + "media_meta_exif": mediaMetaPostProcessor, + "media_meta_music": mediaMetaPostProcessor, + "media_meta_ffprobe": mediaMetaPostProcessor, + "smtpUser": emailPostProcessor, + "smtpPass": emailPostProcessor, + "smtpHost": emailPostProcessor, + "smtpPort": emailPostProcessor, + "smtpEncryption": emailPostProcessor, + "smtpFrom": emailPostProcessor, + "replyTo": emailPostProcessor, + "fromName": emailPostProcessor, + "fromAdress": emailPostProcessor, + "queue_media_meta_worker_num": mediaMetaQueuePostProcessor, + "queue_media_meta_max_execution": mediaMetaQueuePostProcessor, + "queue_media_meta_backoff_factor": mediaMetaQueuePostProcessor, + "queue_media_meta_backoff_max_duration": mediaMetaQueuePostProcessor, + "queue_media_meta_max_retry": mediaMetaQueuePostProcessor, + "queue_media_meta_retry_delay": mediaMetaQueuePostProcessor, + "queue_thumb_worker_num": thumbQueuePostProcessor, + "queue_thumb_max_execution": thumbQueuePostProcessor, + "queue_thumb_backoff_factor": thumbQueuePostProcessor, + "queue_thumb_backoff_max_duration": thumbQueuePostProcessor, + "queue_thumb_max_retry": thumbQueuePostProcessor, + "queue_thumb_retry_delay": thumbQueuePostProcessor, + "queue_recycle_worker_num": entityRecycleQueuePostProcessor, + "queue_recycle_max_execution": entityRecycleQueuePostProcessor, + "queue_recycle_backoff_factor": entityRecycleQueuePostProcessor, + "queue_recycle_backoff_max_duration": entityRecycleQueuePostProcessor, + "queue_recycle_max_retry": entityRecycleQueuePostProcessor, + "queue_recycle_retry_delay": entityRecycleQueuePostProcessor, + "queue_io_intense_worker_num": ioIntenseQueuePostProcessor, + "queue_io_intense_max_execution": ioIntenseQueuePostProcessor, + "queue_io_intense_backoff_factor": ioIntenseQueuePostProcessor, + "queue_io_intense_backoff_max_duration": ioIntenseQueuePostProcessor, + "queue_io_intense_max_retry": ioIntenseQueuePostProcessor, + "queue_io_intense_retry_delay": ioIntenseQueuePostProcessor, + "queue_remote_download_worker_num": remoteDownloadQueuePostProcessor, + "queue_remote_download_max_execution": remoteDownloadQueuePostProcessor, + "queue_remote_download_backoff_factor": remoteDownloadQueuePostProcessor, + 
"queue_remote_download_backoff_max_duration": remoteDownloadQueuePostProcessor, + "queue_remote_download_max_retry": remoteDownloadQueuePostProcessor, + "queue_remote_download_retry_delay": remoteDownloadQueuePostProcessor, + "secret_key": secretKeyPostProcessor, + } +) + +func (s *SetSettingService) SetSetting(c *gin.Context) (map[string]string, error) { + dep := dependency.FromContext(c) + kv := dep.KV() + settingClient := dep.SettingClient() - if err := tx.Model(&model.Setting{}).Where("name = ?", setting.Key).Update("value", setting.Value).Error; err != nil { - cache.Deletes(cacheClean, "setting_") - tx.Rollback() - return serializer.Err(serializer.CodeUpdateSetting, "Setting "+setting.Key+" failed to update", err) + // Preprocess settings + allPreprocessors := make(map[string]SettingPreProcessor) + allPostprocessors := make(map[string]SettingPostProcessor) + for k, _ := range s.Settings { + if preprocessor, ok := preprocessors[k]; ok { + fnName := reflect.TypeOf(preprocessor).Name() + if _, ok := allPreprocessors[fnName]; !ok { + allPreprocessors[fnName] = preprocessor + } } - cacheClean = append(cacheClean, setting.Key) + if postprocessor, ok := postprocessors[k]; ok { + fnName := reflect.TypeOf(postprocessor).Name() + if _, ok := allPostprocessors[fnName]; !ok { + allPostprocessors[fnName] = postprocessor + } + } } - if err := tx.Commit().Error; err != nil { - return serializer.DBErr("Failed to update setting", err) + // Execute all preprocessors + for _, preprocessor := range allPreprocessors { + if err := preprocessor(c, s.Settings); err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Failed to validate settings", err) + } } - cache.Deletes(cacheClean, "setting_") + // Save to db + sc, tx, ctx, err := inventory.WithTx(c, settingClient) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create transaction", err) + } - return serializer.Response{} -} + if err := sc.Set(ctx, s.Settings); err != nil { + _ = inventory.Rollback(tx) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to save settings", err) + } -// Summary 获取站点统计概况 -func (service *NoParamService) Summary() serializer.Response { - // 获取版本信息 - versions := map[string]string{ - "backend": conf.BackendVersion, - "db": conf.RequiredDBVersion, - "commit": conf.LastCommit, - "is_pro": conf.IsPro, + if err := inventory.Commit(tx); err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit transaction", err) + } + + // Clean cache + if err := kv.Delete(setting.KvSettingPrefix, lo.Keys(s.Settings)...); err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to clear cache", err) } - if res, ok := cache.Get("admin_summary"); ok { - resMap := res.(map[string]interface{}) - resMap["version"] = versions - resMap["siteURL"] = model.GetSettingByName("siteURL") - return serializer.Response{Data: resMap} + // Execute post preprocessors + for _, postprocessor := range allPostprocessors { + if err := postprocessor(ctx, s.Settings); err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Failed to post process settings", err) + } } - // 统计每日概况 - total := 12 - files := make([]int, total) - users := make([]int, total) - shares := make([]int, total) - date := make([]string, total) + return s.Settings, nil +} - toRound := time.Now() - timeBase := time.Date(toRound.Year(), toRound.Month(), toRound.Day()+1, 0, 0, 0, 0, toRound.Location()) - for day := range files { - start := 
timeBase.Add(-time.Duration(total-day) * time.Hour * 24) - end := timeBase.Add(-time.Duration(total-day-1) * time.Hour * 24) - date[day] = start.Format("1月2日") - model.DB.Model(&model.User{}).Where("created_at BETWEEN ? AND ?", start, end).Count(&users[day]) - model.DB.Model(&model.File{}).Where("created_at BETWEEN ? AND ?", start, end).Count(&files[day]) - model.DB.Model(&model.Share{}).Where("created_at BETWEEN ? AND ?", start, end).Count(&shares[day]) - } - - // 统计总数 - fileTotal := 0 - userTotal := 0 - publicShareTotal := 0 - secretShareTotal := 0 - model.DB.Model(&model.User{}).Count(&userTotal) - model.DB.Model(&model.File{}).Count(&fileTotal) - model.DB.Model(&model.Share{}).Where("password = ?", "").Count(&publicShareTotal) - model.DB.Model(&model.Share{}).Where("password <> ?", "").Count(&secretShareTotal) - - resp := map[string]interface{}{ - "date": date, - "files": files, - "users": users, - "shares": shares, - "version": versions, - "siteURL": model.GetSettingByName("siteURL"), - "fileTotal": fileTotal, - "userTotal": userTotal, - "publicShareTotal": publicShareTotal, - "secretShareTotal": secretShareTotal, - } - - cache.Set("admin_summary", resp, 86400) - return serializer.Response{ - Data: resp, +func siteUrlPreProcessor(ctx context.Context, settings map[string]string) error { + siteURL := settings["siteURL"] + urls := strings.Split(siteURL, ",") + for index, u := range urls { + urlParsed, err := url.Parse(u) + if err != nil { + return fmt.Errorf("Failed to parse siteURL %q: %w", u, err) + } + + urls[index] = urlParsed.String() } + settings["siteURL"] = strings.Join(urls, ",") + return nil } -// ThumbGeneratorTestService 缩略图生成测试服务 -type ThumbGeneratorTestService struct { - Name string `json:"name" binding:"required"` - Executable string `json:"executable" binding:"required"` +func secretKeyPreProcessor(ctx context.Context, settings map[string]string) error { + settings["secret_key"] = util.RandStringRunes(256) + return nil } -// Test 通过获取生成器版本来测试 -func (s *ThumbGeneratorTestService) Test(c *gin.Context) serializer.Response { - version, err := thumb.TestGenerator(c, s.Name, s.Executable) - if err != nil { - return serializer.Err(serializer.CodeParamErr, err.Error(), err) +func mimeMappingPreProcessor(ctx context.Context, settings map[string]string) error { + var mapping map[string]string + if err := json.Unmarshal([]byte(settings["mime_mapping"]), &mapping); err != nil { + return serializer.NewError(serializer.CodeParamErr, "Invalid mime mapping", err) } - return serializer.Response{ - Data: version, - } + return nil +} + +func mimeMappingPostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.MimeDetector(context.WithValue(ctx, dependency.ReloadCtx{}, true)) + + return nil +} + +func mediaMetaPostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.MediaMetaExtractor(context.WithValue(ctx, dependency.ReloadCtx{}, true)) + return nil +} + +func emailPostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.EmailClient(context.WithValue(ctx, dependency.ReloadCtx{}, true)) + return nil +} + +func mediaMetaQueuePostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.MediaMetaQueue(context.WithValue(ctx, dependency.ReloadCtx{}, true)).Start() + return nil +} + +func ioIntenseQueuePostProcessor(ctx context.Context, settings map[string]string) error { + dep 
:= dependency.FromContext(ctx) + dep.IoIntenseQueue(context.WithValue(ctx, dependency.ReloadCtx{}, true)).Start() + return nil +} + +func remoteDownloadQueuePostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.RemoteDownloadQueue(context.WithValue(ctx, dependency.ReloadCtx{}, true)).Start() + return nil +} + +func entityRecycleQueuePostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.EntityRecycleQueue(context.WithValue(ctx, dependency.ReloadCtx{}, true)).Start() + return nil +} + +func thumbQueuePostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.ThumbQueue(context.WithValue(ctx, dependency.ReloadCtx{}, true)).Start() + return nil +} + +func secretKeyPostProcessor(ctx context.Context, settings map[string]string) error { + dep := dependency.FromContext(ctx) + dep.KV().Delete(manager.EntityUrlCacheKeyPrefix) + settings["secret_key"] = "" + return nil } diff --git a/service/admin/task.go b/service/admin/task.go index 2146d467..07c22022 100644 --- a/service/admin/task.go +++ b/service/admin/task.go @@ -1,159 +1,251 @@ package admin import ( - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" + "context" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "github.com/samber/lo" ) -// TaskBatchService 任务批量操作服务 -type TaskBatchService struct { - ID []uint `json:"id" binding:"min=1"` +func GetQueueMetrics(c *gin.Context) ([]QueueMetric, error) { + res := []QueueMetric{} + dep := dependency.FromContext(c) + + mediaMeta := dep.MediaMetaQueue(c) + entityRecycle := dep.EntityRecycleQueue(c) + ioIntense := dep.IoIntenseQueue(c) + remoteDownload := dep.RemoteDownloadQueue(c) + thumb := dep.ThumbQueue(c) + + res = append(res, QueueMetric{ + Name: setting.QueueTypeMediaMeta, + BusyWorkers: mediaMeta.BusyWorkers(), + SuccessTasks: mediaMeta.SuccessTasks(), + FailureTasks: mediaMeta.FailureTasks(), + SubmittedTasks: mediaMeta.SubmittedTasks(), + SuspendingTasks: mediaMeta.SuspendingTasks(), + }) + res = append(res, QueueMetric{ + Name: setting.QueueTypeEntityRecycle, + BusyWorkers: entityRecycle.BusyWorkers(), + SuccessTasks: entityRecycle.SuccessTasks(), + FailureTasks: entityRecycle.FailureTasks(), + SubmittedTasks: entityRecycle.SubmittedTasks(), + SuspendingTasks: entityRecycle.SuspendingTasks(), + }) + res = append(res, QueueMetric{ + Name: setting.QueueTypeIOIntense, + BusyWorkers: ioIntense.BusyWorkers(), + SuccessTasks: ioIntense.SuccessTasks(), + FailureTasks: ioIntense.FailureTasks(), + SubmittedTasks: ioIntense.SubmittedTasks(), + SuspendingTasks: ioIntense.SuspendingTasks(), + }) + res = append(res, QueueMetric{ + Name: setting.QueueTypeRemoteDownload, + BusyWorkers: remoteDownload.BusyWorkers(), + SuccessTasks: remoteDownload.SuccessTasks(), + FailureTasks: remoteDownload.FailureTasks(), + SubmittedTasks: remoteDownload.SubmittedTasks(), + SuspendingTasks: remoteDownload.SuspendingTasks(), + }) + 
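// Editorial sketch, not part of this patch: the five appends in GetQueueMetrics all
// share the same shape, so a small helper could build them. The metricSource interface
// below is an assumption introduced only for illustration; its method names are taken
// from the calls already used in this function.
//
//	type metricSource interface {
//		BusyWorkers() int
//		SuccessTasks() int
//		FailureTasks() int
//		SubmittedTasks() int
//		SuspendingTasks() int
//	}
//
//	func queueMetric(name setting.QueueType, q metricSource) QueueMetric {
//		return QueueMetric{
//			Name:            name,
//			BusyWorkers:     q.BusyWorkers(),
//			SuccessTasks:    q.SuccessTasks(),
//			FailureTasks:    q.FailureTasks(),
//			SubmittedTasks:  q.SubmittedTasks(),
//			SuspendingTasks: q.SuspendingTasks(),
//		}
//	}
//
// With such a helper each entry reduces to res = append(res, queueMetric(setting.QueueTypeThumb, thumb)).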
res = append(res, QueueMetric{ + Name: setting.QueueTypeThumb, + BusyWorkers: thumb.BusyWorkers(), + SuccessTasks: thumb.SuccessTasks(), + FailureTasks: thumb.FailureTasks(), + SubmittedTasks: thumb.SubmittedTasks(), + SuspendingTasks: thumb.SuspendingTasks(), + }) + + return res, nil } -// ImportTaskService 导入任务 -type ImportTaskService struct { - UID uint `json:"uid" binding:"required"` - PolicyID uint `json:"policy_id" binding:"required"` - Src string `json:"src" binding:"required,min=1,max=65535"` - Dst string `json:"dst" binding:"required,min=1,max=65535"` - Recursive bool `json:"recursive"` -} +const ( + taskTypeCondition = "task_type" + taskStatusCondition = "task_status" + taskCorrelationIDCondition = "task_correlation_id" + taskUserIDCondition = "task_user_id" +) -// Create 新建导入任务 -func (service *ImportTaskService) Create(c *gin.Context, user *model.User) serializer.Response { - // 创建任务 - job, err := task.NewImportTask(service.UID, service.PolicyID, service.Src, service.Dst, service.Recursive) - if err != nil { - return serializer.DBErr("Failed to create task record.", err) +func (s *AdminListService) Tasks(c *gin.Context) (*ListTaskResponse, error) { + dep := dependency.FromContext(c) + taskClient := dep.TaskClient() + hasher := dep.HashIDEncoder() + var ( + err error + userID int + correlationID *uuid.UUID + status []task.Status + taskType []string + ) + + if s.Conditions[taskTypeCondition] != "" { + taskType = []string{s.Conditions[taskTypeCondition]} } - task.TaskPoll.Submit(job) - return serializer.Response{} -} -// Delete 删除任务 -func (service *TaskBatchService) Delete(c *gin.Context) serializer.Response { - if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Download{}).Error; err != nil { - return serializer.DBErr("Failed to delete task records", err) + if s.Conditions[taskStatusCondition] != "" { + status = []task.Status{task.Status(s.Conditions[taskStatusCondition])} } - return serializer.Response{} -} -// DeleteGeneral 删除常规任务 -func (service *TaskBatchService) DeleteGeneral(c *gin.Context) serializer.Response { - if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Task{}).Error; err != nil { - return serializer.DBErr("Failed to delete task records", err) + if s.Conditions[taskCorrelationIDCondition] != "" { + cid, err := uuid.FromString(s.Conditions[taskCorrelationIDCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid task correlation ID", err) + } + correlationID = &cid } - return serializer.Response{} -} -// Tasks 列出常规任务 -func (service *AdminListService) Tasks() serializer.Response { - var res []model.Task - total := 0 - - tx := model.DB.Model(&model.Task{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) + if s.Conditions[taskUserIDCondition] != "" { + userID, err = strconv.Atoi(s.Conditions[taskUserIDCondition]) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid task user ID", err) + } } - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) - } + ctx := context.WithValue(c, inventory.LoadTaskUser{}, true) + res, err := taskClient.List(ctx, &inventory.ListTaskArgs{ + PaginationArgs: &inventory.PaginationArgs{ + Page: s.Page - 1, + PageSize: s.PageSize, + OrderBy: s.OrderBy, + Order: inventory.OrderDirection(s.OrderDirection), + }, + UserID: userID, + CorrelationID: correlationID, + Types: taskType, + Status: status, + }) - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += k + " like '%" + v + 
"%' OR " - } - search = strings.TrimSuffix(search, " OR ") - tx = tx.Where(search) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list tasks", err) } - // 计算总数用于分页 - tx.Count(&total) - - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + tasks := make([]queue.Task, 0, len(res.Tasks)) + nodeMap := make(map[int]*ent.Node) + for _, t := range res.Tasks { + task, err := queue.NewTaskFromModel(t) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to parse task", err) + } - // 查询对应用户,同时计算HashID - users := make(map[uint]model.User) - for _, file := range res { - users[file.UserID] = model.User{} + summary := task.Summarize(hasher) + if summary != nil && summary.NodeID > 0 { + if _, ok := nodeMap[summary.NodeID]; !ok { + nodeMap[summary.NodeID] = nil + } + } + tasks = append(tasks, task) } - userIDs := make([]uint, 0, len(users)) - for k := range users { - userIDs = append(userIDs, k) + // Get nodes + nodes, err := dep.NodeClient().GetNodeByIds(c, lo.Keys(nodeMap)) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to query nodes", err) } - - var userList []model.User - model.DB.Where("id in (?)", userIDs).Find(&userList) - - for _, v := range userList { - users[v.ID] = v + for _, n := range nodes { + nodeMap[n.ID] = n } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "users": users, - }} + return &ListTaskResponse{ + Pagination: res.PaginationResults, + Tasks: lo.Map(res.Tasks, func(task *ent.Task, i int) GetTaskResponse { + var ( + uid string + node *ent.Node + summary *queue.Summary + ) + + if task.Edges.User != nil { + uid = hashid.EncodeUserID(hasher, task.Edges.User.ID) + } + + t := tasks[i] + summary = t.Summarize(hasher) + if summary != nil && summary.NodeID > 0 { + node = nodeMap[summary.NodeID] + } + + return GetTaskResponse{ + Task: task, + UserHashID: uid, + Node: node, + Summary: summary, + } + }), + }, nil } -// Downloads 列出离线下载任务 -func (service *AdminListService) Downloads() serializer.Response { - var res []model.Download - total := 0 - - tx := model.DB.Model(&model.Download{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) +type ( + SingleTaskService struct { + ID int `uri:"id" json:"id" binding:"required"` } + SingleTaskParamCtx struct{} +) - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) +func (s *SingleTaskService) Get(c *gin.Context) (*GetTaskResponse, error) { + dep := dependency.FromContext(c) + taskClient := dep.TaskClient() + hasher := dep.HashIDEncoder() + + ctx := context.WithValue(c, inventory.LoadTaskUser{}, true) + task, err := taskClient.GetTaskByID(ctx, s.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get task", err) } - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += k + " like '%" + v + "%' OR " - } - search = strings.TrimSuffix(search, " OR ") - tx = tx.Where(search) + t, err := queue.NewTaskFromModel(task) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to parse task", err) } - // 计算总数用于分页 - tx.Count(&total) + summary := t.Summarize(hasher) + var ( + node *ent.Node + userHashID string + ) - // 查询记录 - tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + if summary != nil && summary.NodeID > 0 { + node, _ = dep.NodeClient().GetNodeById(c, summary.NodeID) + } - // 
查询对应用户,同时计算HashID - users := make(map[uint]model.User) - for _, file := range res { - users[file.UserID] = model.User{} + if task.Edges.User != nil { + userHashID = hashid.EncodeUserID(hasher, task.Edges.User.ID) } - userIDs := make([]uint, 0, len(users)) - for k := range users { - userIDs = append(userIDs, k) + return &GetTaskResponse{ + Task: task, + Summary: summary, + Node: node, + UserHashID: userHashID, + }, nil +} + +type ( + BatchTaskService struct { + IDs []int `json:"ids" binding:"required"` } + BatchTaskParamCtx struct{} +) - var userList []model.User - model.DB.Where("id in (?)", userIDs).Find(&userList) +func (s *BatchTaskService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + taskClient := dep.TaskClient() - for _, v := range userList { - users[v.ID] = v + err := taskClient.DeleteByIDs(c, s.IDs...) + if err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete tasks", err) } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - "users": users, - }} + return nil } diff --git a/service/admin/tools.go b/service/admin/tools.go new file mode 100644 index 00000000..20dc90f1 --- /dev/null +++ b/service/admin/tools.go @@ -0,0 +1,170 @@ +package admin + +import ( + "encoding/hex" + "net/http" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + request2 "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/wopi" + "github.com/gin-gonic/gin" + "github.com/go-mail/mail" +) + +type ( + HashIDService struct { + ID int `json:"id"` + Type int `json:"type"` + HashID string `json:"hash_id"` + } + HashIDParamCtx struct{} +) + +func (service *HashIDService) Encode(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + res, err := dep.HashIDEncoder().Encode([]int{service.ID, service.Type}) + if err != nil { + return "", err + } + return res, nil +} + +func (service *HashIDService) Decode(c *gin.Context) (int, error) { + dep := dependency.FromContext(c) + res, err := dep.HashIDEncoder().Decode(service.HashID, service.Type) + if err != nil { + return 0, err + } + + return res, nil +} + +type ( + BsEncodeService struct { + Bool []int `json:"bool"` + } + BsEncodeParamCtx struct{} + BsEncodeRes struct { + Hex string + B64 []byte + } +) + +func (service *BsEncodeService) Encode(c *gin.Context) (*BsEncodeRes, error) { + bs := &boolset.BooleanSet{} + for _, v := range service.Bool { + boolset.Set(v, true, bs) + } + + res, err := bs.MarshalBinary() + if err != nil { + return nil, err + } + + return &BsEncodeRes{ + Hex: hex.EncodeToString(res), + B64: res, + }, nil +} + +type ( + BsDecodeService struct { + Code string `json:"code"` + } + BsDecodeParamCtx struct{} + BsDecodeRes struct { + Bool []int `json:"bool"` + } +) + +func (service *BsDecodeService) Decode(c *gin.Context) (*BsDecodeRes, error) { + bs, err := boolset.FromString(service.Code) + if err != nil { + return nil, err + } + + res := []int{} + for i := 0; i < len(*bs)*8; i++ { + if bs.Enabled(i) { + res = append(res, i) + } + } + + return &BsDecodeRes{ + Bool: res, + }, nil +} + +type ( + FetchWOPIDiscoveryService struct { + Endpoint string `form:"endpoint" binding:"required"` + } + FetchWOPIDiscoveryParamCtx struct{} +) + +func (s *FetchWOPIDiscoveryService) Fetch(c 
*gin.Context) (*setting.ViewerGroup, error) { + dep := dependency.FromContext(c) + requestClient := dep.RequestClient(request2.WithContext(c), request2.WithLogger(dep.Logger())) + content, err := requestClient.Request("GET", s.Endpoint, nil).CheckHTTPResponse(http.StatusOK).GetResponse() + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "WOPI endpoint id unavailable", err) + } + + vg, err := wopi.DiscoveryXmlToViewerGroup(content) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Failed to parse WOPI response", err) + } + + return vg, nil +} + +type ( + TestSMTPService struct { + Settings map[string]string `json:"settings" binding:"required"` + To string `json:"to" binding:"required,email"` + } + TestSMTPParamCtx struct{} +) + +func (s *TestSMTPService) Test(c *gin.Context) error { + port, err := strconv.Atoi(s.Settings["smtpPort"]) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "Invalid SMTP port", err) + } + + d := mail.NewDialer(s.Settings["smtpHost"], port, s.Settings["smtpUser"], s.Settings["smtpPass"]) + d.SSL = false + if setting.IsTrueValue(s.Settings["smtpEncryption"]) { + d.SSL = true + } + d.StartTLSPolicy = mail.OpportunisticStartTLS + + sender, err := d.Dial() + if err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to connect to SMTP server: "+err.Error(), err) + } + + m := mail.NewMessage() + m.SetHeader("From", s.Settings["fromAdress"]) + m.SetAddressHeader("Reply-To", s.Settings["replyTo"], s.Settings["fromName"]) + m.SetHeader("To", s.To) + m.SetHeader("Subject", "Cloudreve SMTP Test") + m.SetBody("text/plain", "This is a test email from Cloudreve.") + + err = mail.Send(sender, m) + if err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to send test email: "+err.Error(), err) + } + + return nil +} + +func ClearEntityUrlCache(c *gin.Context) { + dep := dependency.FromContext(c) + dep.KV().Delete(manager.EntityUrlCacheKeyPrefix) +} diff --git a/service/admin/user.go b/service/admin/user.go index eb76ac9a..5ddc0670 100644 --- a/service/admin/user.go +++ b/service/admin/user.go @@ -2,17 +2,24 @@ package admin import ( "context" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" + "github.com/samber/lo" ) // AddUserService 用户添加服务 type AddUserService struct { - User model.User `json:"User" binding:"required"` - Password string `json:"password"` + //User model.User `json:"User" binding:"required"` + Password string `json:"password"` } // UserService 用户ID服务 @@ -25,148 +32,219 @@ type UserBatchService struct { ID []uint `json:"id" binding:"min=1"` } -// Ban 封禁/解封用户 -func (service *UserService) Ban() serializer.Response { - user, err := model.GetUserByID(service.ID) - if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "", err) - } - - if user.ID == 1 { - return serializer.Err(serializer.CodeInvalidActionOnDefaultUser, "", err) - } +const ( + 
userStatusCondition = "user_status" + userGroupCondition = "user_group" + userNickCondition = "user_nick" + userEmailCondition = "user_email" +) - if user.Status == model.Active { - user.SetStatus(model.Baned) - } else { - user.SetStatus(model.Active) - } +func (service *AdminListService) Users(c *gin.Context) (*ListUserResponse, error) { + dep := dependency.FromContext(c) + hasher := dep.HashIDEncoder() + userClient := dep.UserClient() - return serializer.Response{Data: user.Status} -} + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + ctx = context.WithValue(ctx, inventory.LoadUserPasskey{}, true) -// Delete 删除用户 -func (service *UserBatchService) Delete() serializer.Response { - for _, uid := range service.ID { - user, err := model.GetUserByID(uid) + var ( + err error + groupID int + ) + if service.Conditions[userGroupCondition] != "" { + groupID, err = strconv.Atoi(service.Conditions[userGroupCondition]) if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid group ID", err) } + } - // 不能删除初始用户 - if uid == 1 { - return serializer.Err(serializer.CodeInvalidActionOnDefaultUser, "", err) - } + res, err := userClient.ListUsers(ctx, &inventory.ListUserParameters{ + PaginationArgs: &inventory.PaginationArgs{ + Page: service.Page - 1, + PageSize: service.PageSize, + OrderBy: service.OrderBy, + Order: inventory.OrderDirection(service.OrderDirection), + }, + Status: user.Status(service.Conditions[userStatusCondition]), + GroupID: groupID, + Nick: service.Conditions[userNickCondition], + Email: service.Conditions[userEmailCondition], + }) - // 删除与此用户相关的所有资源 + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list users", err) + } - fs, err := filesystem.NewFileSystem(&user) - // 删除所有文件 - root, err := fs.User.Root() - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "User's root folder not exist", err) - } - fs.Delete(context.Background(), []uint{root.ID}, []uint{}, false, false) + return &ListUserResponse{ + Pagination: res.PaginationResults, + Users: lo.Map(res.Users, func(user *ent.User, _ int) GetUserResponse { + return GetUserResponse{ + User: user, + HashID: hashid.EncodeUserID(hasher, user.ID), + TwoFAEnabled: user.TwoFactorSecret != "", + } + }), + }, nil +} - // 删除相关任务 - model.DB.Where("user_id = ?", uid).Delete(&model.Download{}) - model.DB.Where("user_id = ?", uid).Delete(&model.Task{}) +type ( + SingleUserService struct { + ID int `uri:"id" json:"id" binding:"required"` + } + SingleUserParamCtx struct{} +) - // 删除标签 - model.DB.Where("user_id = ?", uid).Delete(&model.Tag{}) +func (service *SingleUserService) Get(c *gin.Context) (*GetUserResponse, error) { + dep := dependency.FromContext(c) + hasher := dep.HashIDEncoder() + userClient := dep.UserClient() - // 删除WebDAV账号 - model.DB.Where("user_id = ?", uid).Delete(&model.Webdav{}) + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + ctx = context.WithValue(ctx, inventory.LoadUserPasskey{}, true) - // 删除此用户 - model.DB.Unscoped().Delete(user) + user, err := userClient.GetByID(ctx, service.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user", err) + } + m := manager.NewFileManager(dep, user) + capacity, err := m.Capacity(ctx) + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to get user capacity", err) } - return serializer.Response{} + + return &GetUserResponse{ + User: user, + 
HashID: hashid.EncodeUserID(hasher, user.ID), + TwoFAEnabled: user.TwoFactorSecret != "", + Capacity: capacity, + }, nil } -// Get 获取用户详情 -func (service *UserService) Get() serializer.Response { - group, err := model.GetUserByID(service.ID) +func (service *SingleUserService) CalibrateStorage(c *gin.Context) (*GetUserResponse, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + _, err := userClient.CalculateStorage(ctx, service.ID) if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "", err) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to calculate storage", err) } - return serializer.Response{Data: group} + subService := &SingleUserService{ID: service.ID} + return subService.Get(c) } -// Add 添加用户 -func (service *AddUserService) Add() serializer.Response { - if service.User.ID > 0 { +type ( + UpsertUserService struct { + User *ent.User `json:"user" binding:"required"` + Password string `json:"password"` + TwoFA string `json:"two_fa"` + } + UpsertUserParamCtx struct{} +) - user, _ := model.GetUserByID(service.User.ID) - if service.Password != "" { - user.SetPassword(service.Password) - } +func (s *UpsertUserService) Update(c *gin.Context) (*GetUserResponse, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() - // 只更新必要字段 - user.Nick = service.User.Nick - user.Email = service.User.Email - user.GroupID = service.User.GroupID - user.Status = service.User.Status - user.TwoFactor = service.User.TwoFactor - - // 检查愚蠢操作 - if user.ID == 1 { - if user.GroupID != 1 { - return serializer.Err(serializer.CodeChangeGroupForDefaultUser, "", nil) - } - if user.Status != model.Active { - return serializer.Err(serializer.CodeInvalidActionOnDefaultUser, "", nil) - } - } + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + existing, err := userClient.GetByID(ctx, s.User.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user", err) + } - if err := model.DB.Save(&user).Error; err != nil { - return serializer.DBErr("Failed to save user record", err) + if s.User.ID == 1 && existing.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) { + if s.User.GroupUsers != existing.GroupUsers { + return nil, serializer.NewError(serializer.CodeInvalidActionOnDefaultUser, "Cannot change default user's group", nil) } - } else { - service.User.SetPassword(service.Password) - if err := model.DB.Create(&service.User).Error; err != nil { - return serializer.DBErr("Failed to create user record", err) + + if s.User.Status != user.StatusActive { + return nil, serializer.NewError(serializer.CodeInvalidActionOnDefaultUser, "Cannot change default user's status", nil) } + + } + + newUser, err := userClient.Upsert(ctx, s.User, s.Password, s.TwoFA) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update user", err) } - return serializer.Response{Data: service.User.ID} + service := &SingleUserService{ID: newUser.ID} + return service.Get(c) } -// Users 列出用户 -func (service *AdminListService) Users() serializer.Response { - var res []model.User - total := 0 +func (s *UpsertUserService) Create(c *gin.Context) (*GetUserResponse, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() - tx := model.DB.Model(&model.User{}) - if service.OrderBy != "" { - tx = tx.Order(service.OrderBy) + if s.Password == "" { + return nil, serializer.NewError(serializer.CodeParamErr, 
"Password is required", nil) } - for k, v := range service.Conditions { - tx = tx.Where(k+" = ?", v) + if s.User.ID != 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "ID must be 0", nil) } - if len(service.Searches) > 0 { - search := "" - for k, v := range service.Searches { - search += (k + " like '%" + v + "%' OR ") - } - search = strings.TrimSuffix(search, " OR ") - tx = tx.Where(search) + user, err := userClient.Upsert(c, s.User, s.Password, s.TwoFA) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create user", err) + } + + service := &SingleUserService{ID: user.ID} + return service.Get(c) + +} + +type ( + BatchUserService struct { + IDs []int `json:"ids" binding:"min=1"` } + BatchUserParamCtx struct{} +) + +func (s *BatchUserService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + fileClient := dep.FileClient() + + current := inventory.UserFromContext(c) + ae := serializer.NewAggregateError() + for _, id := range s.IDs { + if current.ID == id || id == 1 { + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeInvalidActionOnDefaultUser, "Cannot delete current user", nil)) + continue + } + + fc, tx, ctx, err := inventory.WithTx(c, fileClient) + if err != nil { + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)) + continue + } - // 计算总数用于分页 - tx.Count(&total) + uc, _, ctx, err := inventory.WithTx(ctx, userClient) + if err != nil { + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)) + continue + } - // 查询记录 - tx.Set("gorm:auto_preload", true).Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + if err := fc.DeleteByUser(ctx, id); err != nil { + _ = inventory.Rollback(tx) + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeDBError, "Failed to delete user files", err)) + continue + } - // 补齐缺失用户组 + if err := uc.Delete(ctx, id); err != nil { + _ = inventory.Rollback(tx) + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeDBError, "Failed to delete user", err)) + continue + } + + if err := inventory.Commit(tx); err != nil { + ae.Add(strconv.Itoa(id), serializer.NewError(serializer.CodeDBError, "Failed to commit transaction", err)) + continue + } + } - return serializer.Response{Data: map[string]interface{}{ - "total": total, - "items": res, - }} + return ae.Aggregate() } diff --git a/service/aria2/add.go b/service/aria2/add.go deleted file mode 100644 index 816c57bb..00000000 --- a/service/aria2/add.go +++ /dev/null @@ -1,151 +0,0 @@ -package aria2 - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" -) - -// AddURLService 添加URL离线下载服务 -type BatchAddURLService struct { - URLs []string `json:"url" binding:"required"` - Dst string `json:"dst" binding:"required,min=1"` -} - -// Add 主机批量创建新的链接离线下载任务 -func (service *BatchAddURLService) Add(c *gin.Context, taskType int) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return 
serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.Aria2 { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - // 存放目录是否存在 - if exist, _ := fs.IsPathExist(service.Dst); !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - - // 检查批量任务数量 - limit := fs.User.Group.OptionsSerialized.Aria2BatchSize - if limit > 0 && len(service.URLs) > limit { - return serializer.Err(serializer.CodeBatchAria2Size, "", nil) - } - - res := make([]serializer.Response, 0, len(service.URLs)) - for _, target := range service.URLs { - subService := &AddURLService{ - URL: target, - Dst: service.Dst, - } - - addRes := subService.Add(c, fs, taskType) - res = append(res, addRes) - } - - return serializer.Response{Data: res} -} - -// AddURLService 添加URL离线下载服务 -type AddURLService struct { - URL string `json:"url" binding:"required"` - Dst string `json:"dst" binding:"required,min=1"` -} - -// Add 主机创建新的链接离线下载任务 -func (service *AddURLService) Add(c *gin.Context, fs *filesystem.FileSystem, taskType int) serializer.Response { - if fs == nil { - var err error - // 创建文件系统 - fs, err = filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.Aria2 { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - // 存放目录是否存在 - if exist, _ := fs.IsPathExist(service.Dst); !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - } - - downloads := model.GetDownloadsByStatusAndUser(0, fs.User.ID, common.Downloading, common.Paused, common.Ready) - limit := fs.User.Group.OptionsSerialized.Aria2BatchSize - if limit > 0 && len(downloads)+1 > limit { - return serializer.Err(serializer.CodeBatchAria2Size, "", nil) - } - - // 创建任务 - task := &model.Download{ - Status: common.Ready, - Type: taskType, - Dst: service.Dst, - UserID: fs.User.ID, - Source: service.URL, - } - - // 获取 Aria2 负载均衡器 - lb := aria2.GetLoadBalancer() - - // 获取 Aria2 实例 - err, node := cluster.Default.BalanceNodeByFeature("aria2", lb) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to get Aria2 instance", err) - } - - // 创建任务 - gid, err := node.GetAria2Instance().CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options) - if err != nil { - return serializer.Err(serializer.CodeCreateTaskError, "", err) - } - - task.GID = gid - task.NodeID = node.ID() - _, err = task.Create() - if err != nil { - return serializer.DBErr("Failed to create task record", err) - } - - // 创建任务监控 - monitor.NewMonitor(task, cluster.Default, mq.GlobalMQ) - - return serializer.Response{} -} - -// Add 从机创建新的链接离线下载任务 -func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { - caller, _ := c.Get("MasterAria2Instance") - - // 创建任务 - gid, err := caller.(common.Aria2).CreateTask(service.Task, service.GroupOptions) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to create aria2 task", err) - } - - // 创建事件通知回调 - siteID, _ := c.Get("MasterSiteID") - mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) { - if err := cluster.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { - util.Log().Warning("Failed to send remote download task status change notifications: %s", err) - } - }) - - return serializer.Response{Data: gid} -} diff --git a/service/aria2/manage.go 
b/service/aria2/manage.go deleted file mode 100644 index 35ccdff0..00000000 --- a/service/aria2/manage.go +++ /dev/null @@ -1,172 +0,0 @@ -package aria2 - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" -) - -// SelectFileService 选择要下载的文件服务 -type SelectFileService struct { - Indexes []int `json:"indexes" binding:"required"` -} - -// DownloadTaskService 下载任务管理服务 -type DownloadTaskService struct { - GID string `uri:"gid" binding:"required"` -} - -// DownloadListService 下载列表服务 -type DownloadListService struct { - Page uint `form:"page"` -} - -// Finished 获取已完成的任务 -func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { - // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown) - for key, download := range downloads { - node := cluster.Default.GetNodeByID(download.GetNodeID()) - if node != nil { - downloads[key].NodeName = node.DBModel().Name - } - } - - return serializer.BuildFinishedListResponse(downloads) -} - -// Downloading 获取正在下载中的任务 -func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { - // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) - intervals := make(map[uint]int) - for key, download := range downloads { - if _, ok := intervals[download.ID]; !ok { - if node := cluster.Default.GetNodeByID(download.GetNodeID()); node != nil { - intervals[download.ID] = node.DBModel().Aria2OptionsSerialized.Interval - } - } - - node := cluster.Default.GetNodeByID(download.GetNodeID()) - if node != nil { - downloads[key].NodeName = node.DBModel().Name - } - } - - return serializer.BuildDownloadingResponse(downloads, intervals) -} - -// Delete 取消或删除下载任务 -func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) - - // 查找下载记录 - download, err := model.GetDownloadByGid(c.Param("gid"), user.ID) - if err != nil { - return serializer.Err(serializer.CodeNotFound, "Download record not found", err) - } - - if download.Status >= common.Error && download.Status <= common.Unknown { - // 如果任务已完成,则删除任务记录 - if err := download.Delete(); err != nil { - return serializer.DBErr("Failed to delete task record", err) - } - return serializer.Response{} - } - - // 取消任务 - node := cluster.Default.GetNodeByID(download.GetNodeID()) - if node == nil { - return serializer.Err(serializer.CodeNodeOffline, "", err) - } - - if err := node.GetAria2Instance().Cancel(download); err != nil { - return serializer.Err(serializer.CodeNotSet, "Operation failed", err) - } - - return serializer.Response{} -} - -// Select 选取要下载的文件 -func (service *SelectFileService) Select(c *gin.Context) serializer.Response { - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) - - // 查找下载记录 - download, err := model.GetDownloadByGid(c.Param("gid"), user.ID) - if err != nil { - return serializer.Err(serializer.CodeNotFound, "Download record not found", err) - } - - if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != common.Downloading && download.Status != common.Paused) { - return serializer.ParamErr("You cannot select files for this task", nil) - } - - // 选取下载 - node := 
cluster.Default.GetNodeByID(download.GetNodeID()) - if err := node.GetAria2Instance().Select(download, service.Indexes); err != nil { - return serializer.Err(serializer.CodeNotSet, "Operation failed", err) - } - - return serializer.Response{} - -} - -// SlaveStatus 从机查询离线任务状态 -func SlaveStatus(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { - caller, _ := c.Get("MasterAria2Instance") - - // 查询任务 - status, err := caller.(common.Aria2).Status(service.Task) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to query remote download task status", err) - } - - return serializer.NewResponseWithGobData(status) - -} - -// SlaveCancel 取消从机离线下载任务 -func SlaveCancel(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { - caller, _ := c.Get("MasterAria2Instance") - - // 查询任务 - err := caller.(common.Aria2).Cancel(service.Task) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to cancel task", err) - } - - return serializer.Response{} - -} - -// SlaveSelect 从机选取离线下载任务文件 -func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { - caller, _ := c.Get("MasterAria2Instance") - - // 查询任务 - err := caller.(common.Aria2).Select(service.Task, service.Files) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to select files", err) - } - - return serializer.Response{} - -} - -// SlaveSelect 从机选取离线下载任务文件 -func SlaveDeleteTemp(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { - caller, _ := c.Get("MasterAria2Instance") - - // 查询任务 - err := caller.(common.Aria2).DeleteTempFile(service.Task) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to delete temp files", err) - } - - return serializer.Response{} - -} diff --git a/service/basic/site.go b/service/basic/site.go new file mode 100644 index 00000000..814cfea3 --- /dev/null +++ b/service/basic/site.go @@ -0,0 +1,178 @@ +package basic + +import ( + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/service/user" + "github.com/gin-gonic/gin" + "github.com/mojocn/base64Captcha" +) + +// SiteConfig 站点全局设置序列 +type SiteConfig struct { + // Basic Section + InstanceID string `json:"instance_id,omitempty"` + SiteName string `json:"title,omitempty"` + Themes string `json:"themes,omitempty"` + DefaultTheme string `json:"default_theme,omitempty"` + User *user.User `json:"user,omitempty"` + Logo string `json:"logo,omitempty"` + LogoLight string `json:"logo_light,omitempty"` + + // Login Section + LoginCaptcha bool `json:"login_captcha,omitempty"` + RegCaptcha bool `json:"reg_captcha,omitempty"` + ForgetCaptcha bool `json:"forget_captcha,omitempty"` + Authn bool `json:"authn,omitempty"` + ReCaptchaKey string `json:"captcha_ReCaptchaKey,omitempty"` + CaptchaType setting.CaptchaType `json:"captcha_type,omitempty"` + TurnstileSiteID string `json:"turnstile_site_id,omitempty"` + RegisterEnabled bool `json:"register_enabled,omitempty"` + TosUrl string `json:"tos_url,omitempty"` + PrivacyPolicyUrl string `json:"privacy_policy_url,omitempty"` + + // Explorer section + Icons string `json:"icons,omitempty"` + EmojiPreset string `json:"emoji_preset,omitempty"` + MapProvider setting.MapProvider `json:"map_provider,omitempty"` + GoogleMapTileType setting.MapGoogleTileType `json:"google_map_tile_type,omitempty"` + FileViewers 
[]setting.ViewerGroup `json:"file_viewers,omitempty"` + MaxBatchSize int `json:"max_batch_size,omitempty"` + ThumbnailWidth int `json:"thumbnail_width,omitempty"` + ThumbnailHeight int `json:"thumbnail_height,omitempty"` + + // App settings + AppPromotion bool `json:"app_promotion,omitempty"` + + //EmailActive bool `json:"emailActive"` + //QQLogin bool `json:"QQLogin"` + //ScoreEnabled bool `json:"score_enabled"` + //ShareScoreRate string `json:"share_score_rate"` + //HomepageViewMethod string `json:"home_view_method"` + //ShareViewMethod string `json:"share_view_method"` + //WopiExts []string `json:"wopi_exts"` + //AppFeedbackLink string `json:"app_feedback"` + //AppForumLink string `json:"app_forum"` +} + +type ( + GetSettingService struct { + Section string `uri:"section" binding:"required"` + } + GetSettingParamCtx struct{} +) + +func (s *GetSettingService) GetSiteConfig(c *gin.Context) (*SiteConfig, error) { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + + switch s.Section { + case "login": + legalDocs := settings.LegalDocuments(c) + return &SiteConfig{ + LoginCaptcha: settings.LoginCaptchaEnabled(c), + RegCaptcha: settings.RegCaptchaEnabled(c), + ForgetCaptcha: settings.ForgotPasswordCaptchaEnabled(c), + Authn: settings.AuthnEnabled(c), + RegisterEnabled: settings.RegisterEnabled(c), + PrivacyPolicyUrl: legalDocs.PrivacyPolicy, + TosUrl: legalDocs.TermsOfService, + }, nil + case "explorer": + explorerSettings := settings.ExplorerFrontendSettings(c) + mapSettings := settings.MapSetting(c) + fileViewers := settings.FileViewers(c) + maxBatchSize := settings.MaxBatchedFile(c) + w, h := settings.ThumbSize(c) + for i := range fileViewers { + for j := range fileViewers[i].Viewers { + fileViewers[i].Viewers[j].WopiActions = nil + } + } + return &SiteConfig{ + MaxBatchSize: maxBatchSize, + FileViewers: fileViewers, + Icons: explorerSettings.Icons, + MapProvider: mapSettings.Provider, + GoogleMapTileType: mapSettings.GoogleTileType, + ThumbnailWidth: w, + ThumbnailHeight: h, + }, nil + case "emojis": + emojis := settings.EmojiPresets(c) + return &SiteConfig{ + EmojiPreset: emojis, + }, nil + case "app": + appSetting := settings.AppSetting(c) + return &SiteConfig{ + AppPromotion: appSetting.Promotion, + }, nil + default: + break + } + + u := inventory.UserFromContext(c) + siteBasic := settings.SiteBasic(c) + themes := settings.Theme(c) + userRes := user.BuildUser(u, dep.HashIDEncoder()) + logo := settings.Logo(c) + reCaptcha := settings.ReCaptcha(c) + appSetting := settings.AppSetting(c) + + return &SiteConfig{ + InstanceID: siteBasic.ID, + SiteName: siteBasic.Name, + Themes: themes.Themes, + DefaultTheme: themes.DefaultTheme, + User: &userRes, + Logo: logo.Normal, + LogoLight: logo.Light, + CaptchaType: settings.CaptchaType(c), + TurnstileSiteID: settings.TurnstileCaptcha(c).Key, + ReCaptchaKey: reCaptcha.Key, + AppPromotion: appSetting.Promotion, + }, nil +} + +const ( + CaptchaSessionPrefix = "captcha_session_" + CaptchaTTL = 1800 // 30 minutes +) + +type ( + CaptchaResponse struct { + Image string `json:"image"` + Ticket string `json:"ticket"` + } +) + +// GetCaptchaImage generates captcha session +func GetCaptchaImage(c *gin.Context) *CaptchaResponse { + dep := dependency.FromContext(c) + captchaSettings := dep.SettingProvider().Captcha(c) + var configD = base64Captcha.ConfigCharacter{ + Height: captchaSettings.Height, + Width: captchaSettings.Width, + Mode: int(captchaSettings.Mode), + ComplexOfNoiseText: captchaSettings.ComplexOfNoiseText, + ComplexOfNoiseDot: 
captchaSettings.ComplexOfNoiseDot, + IsShowHollowLine: captchaSettings.IsShowHollowLine, + IsShowNoiseDot: captchaSettings.IsShowNoiseDot, + IsShowNoiseText: captchaSettings.IsShowNoiseText, + IsShowSlimeLine: captchaSettings.IsShowSlimeLine, + IsShowSineLine: captchaSettings.IsShowSineLine, + CaptchaLen: captchaSettings.Length, + } + + // 生成验证码 + idKeyD, capD := base64Captcha.GenerateCaptcha("", configD) + + base64stringD := base64Captcha.CaptchaWriteToBase64Encoding(capD) + + return &CaptchaResponse{ + Image: base64stringD, + Ticket: idKeyD, + } +} diff --git a/service/callback/oauth.go b/service/callback/oauth.go index f93636bb..3f9b6f3a 100644 --- a/service/callback/oauth.go +++ b/service/callback/oauth.go @@ -1,18 +1,8 @@ package callback import ( - "context" - "fmt" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" - "github.com/samber/lo" - "strings" ) // OauthService OAuth 存储策略授权回调服务 @@ -23,110 +13,111 @@ type OauthService struct { Scope string `form:"scope"` } -// GDriveAuth Google Drive 更新认证信息 -func (service *OauthService) GDriveAuth(c *gin.Context) serializer.Response { - if service.Error != "" { - return serializer.ParamErr(service.Error, nil) - } - - // validate required scope - if missing, found := lo.Find[string](googledrive.RequiredScope, func(item string) bool { - return !strings.Contains(service.Scope, item) - }); found { - return serializer.ParamErr(fmt.Sprintf("Missing required scope: %s", missing), nil) - } - - policyID, ok := util.GetSession(c, "googledrive_oauth_policy").(uint) - if !ok { - return serializer.Err(serializer.CodeNotFound, "", nil) - } - - util.DeleteSession(c, "googledrive_oauth_policy") - - policy, err := model.GetPolicyByID(policyID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", nil) - } - - client, err := googledrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) - } - - credential, err := client.ObtainToken(c, service.Code, "") - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) - } - - // 更新存储策略的 RefreshToken - client.Policy.AccessKey = credential.RefreshToken - if err := client.Policy.SaveAndClearCache(); err != nil { - return serializer.DBErr("Failed to update RefreshToken", err) - } - - cache.Deletes([]string{client.Policy.AccessKey}, googledrive.TokenCachePrefix) - return serializer.Response{} -} +// +//// GDriveAuth Google Drive 更新认证信息 +//func (service *OauthService) GDriveAuth(c *gin.Context) serializer.Response { +// if service.Error != "" { +// return serializer.ParamErrDeprecated(service.Error, nil) +// } +// +// // validate required scope +// if missing, found := lo.Find[string](googledrive.RequiredScope, func(item string) bool { +// return !strings.Contains(service.Scope, item) +// }); found { +// return serializer.ParamErrDeprecated(fmt.Sprintf("Missing required scope: %s", missing), nil) +// } +// +// policyID, ok := util.GetSession(c, "googledrive_oauth_policy").(uint) +// if !ok { +// return serializer.ErrDeprecated(serializer.CodeNotFound, "", nil) +// } +// +// util.DeleteSession(c, 
"googledrive_oauth_policy") +// +// policy, err := model.GetPolicyByID(policyID) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodePolicyNotExist, "", nil) +// } +// +// client, err := googledrive.NewClient(&policy) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) +// } +// +// credential, err := client.ObtainToken(c, service.Code, "") +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) +// } +// +// // 更新存储策略的 RefreshToken +// client.Policy.AccessKey = credential.RefreshToken +// if err := client.Policy.SaveAndClearCache(); err != nil { +// return serializer.DBErrDeprecated("Failed to update RefreshToken", err) +// } +// +// cache.Deletes([]string{client.Policy.AccessKey}, googledrive.TokenCachePrefix) +// return serializer.Response{} +//} // OdAuth OneDrive 更新认证信息 func (service *OauthService) OdAuth(c *gin.Context) serializer.Response { - if service.Error != "" { - return serializer.ParamErr(service.ErrorMsg, nil) - } - - policyID, ok := util.GetSession(c, "onedrive_oauth_policy").(uint) - if !ok { - return serializer.Err(serializer.CodeNotFound, "", nil) - } - - util.DeleteSession(c, "onedrive_oauth_policy") - - policy, err := model.GetPolicyByID(policyID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", nil) - } - - client, err := onedrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) - } - - credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code)) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) - } - - // 更新存储策略的 RefreshToken - client.Policy.AccessKey = credential.RefreshToken - if err := client.Policy.SaveAndClearCache(); err != nil { - return serializer.DBErr("Failed to update RefreshToken", err) - } - - cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_") - if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") { - if err := querySharePointSiteID(c, client.Policy); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to query SharePoint site ID", err) - } - } + //if service.Error != "" { + // return serializer.ParamErrDeprecated(service.ErrorMsg, nil) + //} + // + //policyID, ok := util.GetSession(c, "onedrive_oauth_policy").(uint) + //if !ok { + // return serializer.ErrDeprecated(serializer.CodeNotFound, "", nil) + //} + // + //util.DeleteSession(c, "onedrive_oauth_policy") + // + //policy, err := model.GetPolicyByID(policyID) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodePolicyNotExist, "", nil) + //} + // + //client, err := onedrive.NewClient(&policy) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) + //} + // + //credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code)) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) + //} + // + //// 更新存储策略的 RefreshToken + //client.Policy.AccessKey = credential.RefreshToken + //if err := client.Policy.SaveAndClearCache(); err != nil { + // return serializer.DBErrDeprecated("Failed to update RefreshToken", err) + //} + // + //cache.Deletes([]string{client.Policy.AccessKey}, 
"onedrive_") + //if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") { + // if err := querySharePointSiteID(c, client.Policy); err != nil { + // return serializer.ErrDeprecated(serializer.CodeInternalSetting, "Failed to query SharePoint basic ID", err) + // } + //} return serializer.Response{} } -func querySharePointSiteID(ctx context.Context, policy *model.Policy) error { - client, err := onedrive.NewClient(policy) - if err != nil { - return err - } - - id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver) - if err != nil { - return err - } - - client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id) - if err := client.Policy.SaveAndClearCache(); err != nil { - return err - } - - return nil -} +//func querySharePointSiteID(ctx context.Context, policy *model.Policy) error { +//client, err := onedrive.NewClient(policy) +//if err != nil { +// return err +//} +// +//id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver) +//if err != nil { +// return err +//} +// +//client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id) +//if err := client.Policy.SaveAndClearCache(); err != nil { +// return err +//} + +//return nil +//} diff --git a/service/callback/upload.go b/service/callback/upload.go index 0dd7924c..057c75eb 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -1,33 +1,17 @@ package callback import ( - "context" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "strings" - - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/onedrive" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" "github.com/gin-gonic/gin" ) -// CallbackProcessService 上传请求回调正文接口 -type CallbackProcessService interface { - GetBody() serializer.UploadCallback -} - // RemoteUploadCallbackService 远程存储上传回调请求服务 type RemoteUploadCallbackService struct { - Data serializer.UploadCallback `json:"data" binding:"required"` -} - -// GetBody 返回回调正文 -func (service RemoteUploadCallbackService) GetBody() serializer.UploadCallback { - return service.Data } // UploadCallbackService OOS/七牛云存储上传回调请求服务 @@ -35,7 +19,7 @@ type UploadCallbackService struct { Name string `json:"name"` SourceName string `json:"source_name"` PicInfo string `json:"pic_info"` - Size uint64 `json:"size"` + Size int64 `json:"size"` } // UpyunCallbackService 又拍云上传回调请求服务 @@ -53,209 +37,109 @@ type OneDriveCallback struct { Meta *onedrive.FileInfo } -// COSCallback COS 客户端回调正文 -type COSCallback struct { - Bucket string `form:"bucket"` - Etag string `form:"etag"` -} - // S3Callback S3 客户端回调正文 type S3Callback struct { } -// GetBody 返回回调正文 -func (service UpyunCallbackService) GetBody() serializer.UploadCallback { - res := serializer.UploadCallback{} - if service.Width != "" { - res.PicInfo = service.Width + "," + service.Height - } - - return res -} - -// GetBody 返回回调正文 -func (service UploadCallbackService) GetBody() serializer.UploadCallback { - 
return serializer.UploadCallback{ - PicInfo: service.PicInfo, - } -} - -// GetBody 返回回调正文 -func (service OneDriveCallback) GetBody() serializer.UploadCallback { - var picInfo = "0,0" - if service.Meta.Image.Width != 0 { - picInfo = fmt.Sprintf("%d,%d", service.Meta.Image.Width, service.Meta.Image.Height) - } - return serializer.UploadCallback{ - PicInfo: picInfo, - } -} - -// GetBody 返回回调正文 -func (service COSCallback) GetBody() serializer.UploadCallback { - return serializer.UploadCallback{ - PicInfo: "", - } -} - -// GetBody 返回回调正文 -func (service S3Callback) GetBody() serializer.UploadCallback { - return serializer.UploadCallback{ - PicInfo: "", - } -} - // ProcessCallback 处理上传结果回调 -func ProcessCallback(service CallbackProcessService, c *gin.Context) serializer.Response { - callbackBody := service.GetBody() - - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromCallback(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, err.Error(), err) - } - defer fs.Recycle() +func ProcessCallback(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() // 获取上传会话 - uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 查找上传会话创建的占位文件 - file, err := model.GetFilesByUploadSession(uploadSession.Key, fs.User.ID) - if err != nil { - return serializer.Err(serializer.CodeUploadSessionExpired, "LocalUpload session file placeholder not exist", err) - } - - fileData := fsctx.FileStream{ - Size: uploadSession.Size, - Name: uploadSession.Name, - VirtualPath: uploadSession.VirtualPath, - SavePath: uploadSession.SavePath, - Mode: fsctx.Nop, - Model: file, - LastModified: uploadSession.LastModified, - } - - // 占位符未扣除容量需要校验和扣除 - if !fs.Policy.IsUploadPlaceholderWithSize() { - fs.Use("AfterUpload", filesystem.HookValidateCapacity) - fs.Use("AfterUpload", filesystem.HookChunkUploaded) - } - - fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile(callbackBody.PicInfo)) - fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - err = fs.Upload(context.Background(), &fileData) - if err != nil { - return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) - } - - return serializer.Response{} -} - -// PreProcess 对OneDrive客户端回调进行预处理验证 -func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromCallback(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 获取回调会话 - uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 获取文件信息 - info, err := fs.Handler.(onedrive.Driver).Client.Meta(context.Background(), "", uploadSession.SavePath) - if err != nil { - return serializer.Err(serializer.CodeQueryMetaFailed, "", err) - } - - // 验证与回调会话中是否一致 - actualPath := strings.TrimPrefix(uploadSession.SavePath, "/") - isSizeCheckFailed := uploadSession.Size != info.Size - - // SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容 - // See: https://github.com/OneDrive/onedrive-api-docs/issues/935 - if (strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") || strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.cn")) && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) { - isSizeCheckFailed = false - } - - if isSizeCheckFailed || !strings.EqualFold(info.GetSourcePath(), actualPath) { - 
fs.Handler.(onedrive.Driver).Client.Delete(context.Background(), []string{info.GetSourcePath()}) - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } - service.Meta = info - return ProcessCallback(service, c) -} - -// PreProcess 对COS客户端回调进行预处理 -func (service *COSCallback) PreProcess(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromCallback(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 获取回调会话 - uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 获取文件信息 - info, err := fs.Handler.(cos.Driver).Meta(context.Background(), uploadSession.SavePath) - if err != nil { - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } - - // 验证实际文件信息与回调会话中是否一致 - if uploadSession.Size != info.Size || uploadSession.Key != info.CallbackKey { - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } - - return ProcessCallback(service, c) -} - -// PreProcess 对S3客户端回调进行预处理 -func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromCallback(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 获取回调会话 - uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 获取文件信息 - info, err := fs.Handler.(*s3.Driver).Meta(context.Background(), uploadSession.SavePath) - if err != nil { - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } + uploadSession := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession) - // 验证实际文件信息与回调会话中是否一致 - if uploadSession.Size != info.Size { - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } - - return ProcessCallback(service, c) -} - -// PreProcess 对从机客户端回调进行预处理验证 -func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromCallback(c) + _, err := m.CompleteUpload(c, uploadSession) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 获取回调会话 - uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) - - // 验证文件大小 - if uploadSession.Size != service.Size { - fs.Handler.Delete(context.Background(), []string{uploadSession.SavePath}) - return serializer.Err(serializer.CodeMetaMismatch, "", err) - } - - return ProcessCallback(service, c) -} + return fmt.Errorf("failed to complete upload: %w", err) + } + + return nil +} + +//// PreProcess 对OneDrive客户端回调进行预处理验证 +//func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response { +// // 创建文件系统 +// fs, err := filesystem.NewFileSystemFromCallback(c) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) +// } +// defer fs.Recycle() +// +// // 获取回调会话 +// uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) +// +// // 获取文件信息 +// info, err := fs.Handler.(onedrive.Driver).Client.Meta(context.Background(), "", uploadSession.SavePath) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeQueryMetaFailed, "", err) +// } +// +// // 验证与回调会话中是否一致 +// actualPath := strings.TrimPrefix(uploadSession.SavePath, "/") +// isSizeCheckFailed := uploadSession.Size != info.Size +// +// // SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容 +// // See: https://github.com/OneDrive/onedrive-api-docs/issues/935 +// if 
(strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") || strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.cn")) && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) { +// isSizeCheckFailed = false +// } +// +// if isSizeCheckFailed || !strings.EqualFold(info.GetSourcePath(), actualPath) { +// fs.Handler.(onedrive.Driver).Client.Delete(context.Background(), []string{info.GetSourcePath()}) +// return serializer.ErrDeprecated(serializer.CodeMetaMismatch, "", err) +// } +// service.Meta = info +// return ProcessCallback(c) +//} +// + +// +//// PreProcess 对S3客户端回调进行预处理 +//func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response { +// // 创建文件系统 +// fs, err := filesystem.NewFileSystemFromCallback(c) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) +// } +// defer fs.Recycle() +// +// // 获取回调会话 +// uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) +// +// // 获取文件信息 +// info, err := fs.Handler.(*s3.Driver).Meta(context.Background(), uploadSession.SavePath) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeMetaMismatch, "", err) +// } +// +// // 验证实际文件信息与回调会话中是否一致 +// if uploadSession.Size != info.Size { +// return serializer.ErrDeprecated(serializer.CodeMetaMismatch, "", err) +// } +// +// return ProcessCallback(service, c) +//} +// +//// PreProcess 对OneDrive客户端回调进行预处理验证 +//func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response { +// // 创建文件系统 +// fs, err := filesystem.NewFileSystemFromCallback(c) +// if err != nil { +// return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) +// } +// defer fs.Recycle() +// +// // 获取回调会话 +// uploadSession := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession) +// +// // 验证文件大小 +// if uploadSession.Size != service.Size { +// fs.Handler.Delete(context.Background(), []string{uploadSession.SavePath}) +// return serializer.ErrDeprecated(serializer.CodeMetaMismatch, "", err) +// } +// +// return ProcessCallback(service, c) +//} diff --git a/service/explorer/directory.go b/service/explorer/directory.go deleted file mode 100644 index cd03999f..00000000 --- a/service/explorer/directory.go +++ /dev/null @@ -1,68 +0,0 @@ -package explorer - -import ( - "context" - - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" -) - -// DirectoryService 创建新目录服务 -type DirectoryService struct { - Path string `uri:"path" json:"path" binding:"required,min=1,max=65535"` -} - -// ListDirectory 列出目录内容 -func (service *DirectoryService) ListDirectory(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 获取子项目 - objects, err := fs.List(ctx, service.Path, nil) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - var parentID uint - if len(fs.DirTarget) > 0 { - parentID = fs.DirTarget[0].ID - } - - return serializer.Response{ - Code: 0, - Data: serializer.BuildObjectList(parentID, objects, fs.Policy), - } -} - -// CreateDirectory 创建目录 -func (service *DirectoryService) CreateDirectory(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := 
filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 创建目录 - _, err = fs.CreateDirectory(ctx, service.Path) - if err != nil { - return serializer.Err(serializer.CodeCreateFolderFailed, err.Error(), err) - } - return serializer.Response{ - Code: 0, - } - -} diff --git a/service/explorer/entity.go b/service/explorer/entity.go new file mode 100644 index 00000000..d182a26f --- /dev/null +++ b/service/explorer/entity.go @@ -0,0 +1,118 @@ +package explorer + +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" +) + +type ( + EntityDownloadParameterCtx struct{} + EntityDownloadService struct { + Name string `uri:"name" binding:"required"` + SpeedLimit int64 `uri:"speed"` + Src string `uri:"src"` + } +) + +// Serve serves file content +func (s *EntityDownloadService) Serve(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + entitySource, err := m.GetEntitySource(c, hashid.FromContext(c)) + if err != nil { + return fmt.Errorf("failed to get entity source: %w", err) + } + + defer entitySource.Close() + + // Set cache header for public resource + settings := dep.SettingProvider() + maxAge := settings.PublicResourceMaxAge(c) + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", maxAge)) + + isDownload := c.Query(routes.IsDownloadQuery) != "" + isThumb := c.Query(routes.IsThumbQuery) != "" + entitySource.Serve(c.Writer, c.Request, + entitysource.WithSpeedLimit(s.SpeedLimit), + entitysource.WithDownload(isDownload), + entitysource.WithDisplayName(s.Name), + entitysource.WithContext(c), + entitysource.WithThumb(isThumb), + ) + return nil +} + +type ( + SetCurrentVersionParamCtx struct{} + SetCurrentVersionService struct { + Uri string `uri:"uri" binding:"required"` + Version string `uri:"version" binding:"required"` + } +) + +// Set sets the current version of the file +func (s *SetCurrentVersionService) Set(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(s.Uri) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + versionId, err := dep.HashIDEncoder().Decode(s.Version, hashid.EntityID) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown version id", err) + } + + if err := m.SetCurrentVersion(c, uri, versionId); err != nil { + return fmt.Errorf("failed to set current version: %w", err) + } + + return nil +} + +type ( + DeleteVersionParamCtx struct{} + DeleteVersionService struct { + Uri string `uri:"uri" binding:"required"` + Version string `uri:"version" binding:"required"` + } +) + +// Delete deletes the version of the file +func (s *DeleteVersionService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + user := 
inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(s.Uri) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + versionId, err := dep.HashIDEncoder().Decode(s.Version, hashid.EntityID) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown version id", err) + } + + if err := m.DeleteVersion(c, uri, versionId); err != nil { + return fmt.Errorf("failed to delete version: %w", err) + } + + return nil +} diff --git a/service/explorer/file.go b/service/explorer/file.go index 1c9d870d..c17f549a 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -2,24 +2,28 @@ package explorer import ( "context" - "encoding/base64" - "encoding/json" + "encoding/gob" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io/ioutil" "net/http" - "net/url" - "path" - "strconv" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "github.com/samber/lo" ) // SingleFileService 对单文件进行操作的五福,path为文件完整路径 @@ -37,498 +41,633 @@ type FileAnonymousGetService struct { Name string `uri:"name" binding:"required"` } -// DownloadService 文件下載服务 -type DownloadService struct { - ID string `uri:"id" binding:"required"` +func init() { + gob.Register(ArchiveDownloadSession{}) } -// ArchiveService 文件流式打包下載服务 -type ArchiveService struct { - ID string `uri:"sessionID" binding:"required"` +// List 列出从机上的文件 +func (service *SlaveListService) List(c *gin.Context) serializer.Response { + //// 创建文件系统 + //fs, err := filesystem.NewAnonymousFileSystem() + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeCreateFSError, "", err) + //} + //defer fs.Recycle() + // + //objects, err := fs.Handler.List(context.Background(), service.Path, service.Recursive) + //if err != nil { + // return serializer.ErrDeprecated(serializer.CodeIOFailed, "Cannot list files", err) + //} + // + //res, _ := json.Marshal(objects) + //return serializer.Response{Data: string(res)} + + return serializer.Response{} } -// New 创建新文件 -func (service *SingleFileService) Create(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) +// ArchiveService 文件流式打包下載服务 +type ( + ArchiveService struct { + ID string `uri:"sessionID" binding:"required"` + } + ArchiveParamCtx struct{} +) + +// DownloadArchived 通过预签名 URL 打包下载 +func (service *ArchiveService) DownloadArchived(c 
*gin.Context) error { + dep := dependency.FromContext(c) + archiveSessionRaw, found := dep.KV().Get(ArchiveDownloadSessionPrefix + service.ID) + if !found { + return serializer.NewError(serializer.CodeNotFound, "Archive session not exist", nil) + } + + // Switch to user context + archiveSession := archiveSessionRaw.(ArchiveDownloadSession) + requester, err := dep.UserClient().GetLoginUserByID(c, archiveSession.RequesterID) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return serializer.NewError(serializer.CodeNotFound, "Requester not found", err) } - defer fs.Recycle() - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + util.WithValue(c, inventory.UserCtx{}, requester) - // 给文件系统分配钩子 - fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("AfterUpload", filesystem.GenericAfterUpload) + fm := manager.NewFileManager(dep, requester) + defer fm.Recycle() - // 上传空文件 - err = fs.Upload(ctx, &fsctx.FileStream{ - File: ioutil.NopCloser(strings.NewReader("")), - Size: 0, - VirtualPath: path.Dir(service.Path), - Name: path.Base(service.Path), - }) - if err != nil { - return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) + // 开始打包 + c.Header("Content-Disposition", "attachment;") + c.Header("Content-Type", "application/zip") + + if _, err := fm.CreateArchive(c, archiveSession.Uris, c.Writer); err != nil { + return serializer.NewError(serializer.CodeIOFailed, "Failed to create archive", err) } - return serializer.Response{ - Code: 0, + return nil +} + +type ( + GetDirectLinkParamCtx struct{} + GetDirectLinkService struct { + Uris []string `json:"uris" binding:"required,min=1"` } +) + +func (s *GetDirectLinkService) GetUris() []string { + return s.Uris } -// List 列出从机上的文件 -func (service *SlaveListService) List(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) +// Sources 批量获取对象的外链 +func (s *GetDirectLinkService) Get(c *gin.Context) ([]DirectLinkResponse, error) { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + + if u.Edges.Group.Settings.SourceBatchSize == 0 { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "", nil) } - defer fs.Recycle() - objects, err := fs.Handler.List(context.Background(), service.Path, service.Recursive) + if len(s.Uris) > u.Edges.Group.Settings.SourceBatchSize { + return nil, serializer.NewError(serializer.CodeBatchSourceSize, "", nil) + } + + m := manager.NewFileManager(dep, u) + defer m.Recycle() + + uris, err := fs.NewUriFromStrings(s.Uris...) if err != nil { - return serializer.Err(serializer.CodeIOFailed, "Cannot list files", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - res, _ := json.Marshal(objects) - return serializer.Response{Data: string(res)} + res, err := m.GetDirectLink(c, uris...) 
+ return BuildDirectLinkResponse(res), err } -// DownloadArchived 通过预签名 URL 打包下载 -func (service *ArchiveService) DownloadArchived(ctx context.Context, c *gin.Context) serializer.Response { - userRaw, exist := cache.Get("archive_user_" + service.ID) - if !exist { - return serializer.Err(serializer.CodeNotFound, "Archive session not exist", nil) +const defaultPageSize = 100 + +type ( + // ListFileParameterCtx define key fore ListFileService + ListFileParameterCtx struct{} + + // ListFileService stores parameters for list file by URI + ListFileService struct { + Uri string `uri:"uri" form:"uri" json:"uri" binding:"required"` + Page int `uri:"page" form:"page" json:"page" binding:"min=0"` + PageSize int `uri:"page_size" form:"page_size" json:"page_size" binding:"min=10"` + OrderBy string `uri:"order_by" form:"order_by" json:"order_by"` + OrderDirection string `uri:"order_direction" form:"order_direction" json:"order_direction"` + NextPageToken string `uri:"next_page_token" form:"next_page_token" json:"next_page_token"` } - user := userRaw.(model.User) +) - // 创建文件系统 - fs, err := filesystem.NewFileSystem(&user) +// List all files for given path +func (service *ListFileService) List(c *gin.Context) (*ListResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + pageSize := service.PageSize + if pageSize == 0 { + pageSize = defaultPageSize + } + + streamed := false + hasher := dep.HashIDEncoder() + parent, res, err := m.List(c, uri, &manager.ListArgs{ + Page: service.Page, + PageSize: pageSize, + Order: service.OrderBy, + OrderDirection: service.OrderDirection, + PageToken: service.NextPageToken, + StreamResponseCallback: func(parent fs.File, files []fs.File) { + if !streamed { + WriteEventSourceHeader(c) + streamed = true + } + + WriteEventSource(c, "file", lo.Map(files, func(file fs.File, index int) *FileResponse { + return BuildFileResponse(c, user, file, hasher, nil) + })) + }, + }) + if err != nil { + return nil, err } - defer fs.Recycle() - // 查找打包的临时文件 - archiveSession, exist := cache.Get("archive_" + service.ID) - if !exist { - return serializer.Err(serializer.CodeNotFound, "Archive session not exist", nil) + listResponse := BuildListResponse(c, user, parent, res, hasher) + if streamed { + WriteEventSource(c, "list", listResponse) + return nil, ErrSSETakeOver } - // 开始打包 - c.Header("Content-Disposition", "attachment;") - c.Header("Content-Type", "application/zip") - itemService := archiveSession.(ItemIDService) - items := itemService.Raw() - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - err = fs.Compress(ctx, c.Writer, items.Dirs, items.Items, true) + return listResponse, nil +} + +type ( + CreateFileParameterCtx struct{} + CreateFileService struct { + Uri string `json:"uri" binding:"required"` + Type string `json:"type" binding:"required,eq=file|eq=folder"` + Metadata map[string]string `json:"metadata"` + ErrOnConflict bool `json:"err_on_conflict"` + } +) + +func (service *CreateFileService) Create(c *gin.Context) (*FileResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeNotSet, 
"Failed to compress file", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - return serializer.Response{ - Code: 0, + fileType := types.FileTypeFromString(service.Type) + opts := []fs.Option{ + fs.WithMetadata(service.Metadata), + } + if service.ErrOnConflict { + opts = append(opts, dbfs.WithErrorOnConflict()) } + file, err := m.Create(c, uri, fileType, opts...) + if err != nil { + return nil, err + } + + return BuildFileResponse(c, user, file, dep.HashIDEncoder(), nil), nil } -// Download 签名的匿名文件下载 -func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Context) serializer.Response { - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) +type ( + RenameFileParameterCtx struct{} + RenameFileService struct { + Uri string `json:"uri" binding:"required"` + NewName string `json:"new_name" binding:"required,min=1,max=255"` } - defer fs.Recycle() +) + +func (service *RenameFileService) Rename(c *gin.Context) (*FileResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - // 查找文件 - err = fs.SetTargetFileByIDs([]uint{service.ID}) + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - // 获取文件流 - rs, err := fs.GetDownloadContent(ctx, 0) - defer rs.Close() + file, err := m.Rename(c, uri, service.NewName) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, err } - // 发送文件 - http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs) + return BuildFileResponse(c, user, file, dep.HashIDEncoder(), nil), nil +} - return serializer.Response{ - Code: 0, +type ( + MoveFileParameterCtx struct{} + MoveFileService struct { + Uris []string `json:"uris" binding:"required,min=1"` + Dst string `json:"dst" binding:"required"` + Copy bool `json:"copy"` } +) + +func (s *MoveFileService) GetUris() []string { + return s.Uris } -// Source 重定向到文件的有效原始链接 -func (service *FileAnonymousGetService) Source(ctx context.Context, c *gin.Context) serializer.Response { - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() +func (s *MoveFileService) Move(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - // 查找文件 - err = fs.SetTargetFileByIDs([]uint{service.ID}) + uris, err := fs.NewUriFromStrings(s.Uris...) 
if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - // 获取文件流 - ttl := int64(model.GetIntSetting("preview_timeout", 60)) - res, err := fs.SignURL(ctx, &fs.FileTarget[0], ttl, false) + dst, err := fs.NewUriFromString(s.Dst) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return serializer.NewError(serializer.CodeParamErr, "unknown destination uri", err) } - c.Header("Cache-Control", fmt.Sprintf("max-age=%d", ttl)) - return serializer.Response{ - Code: -302, - Data: res, - } + return m.MoveOrCopy(c, uris, dst, s.Copy) } -// CreateDocPreviewSession 创建DOC文件预览会话,返回预览地址 -func (service *FileIDService) CreateDocPreviewSession(ctx context.Context, c *gin.Context, editable bool) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) +type ( + FileUpdateParameterCtx struct{} + FileUpdateService struct { + Uri string `form:"uri" binding:"required"` + Previous string `form:"previous"` } - defer fs.Recycle() - - // 获取对象id - objectID, _ := c.Get("object_id") +) - // 如果上下文中已有File对象,则重设目标 - if file, ok := ctx.Value(fsctx.FileModelCtx).(*model.File); ok { - fs.SetTargetFile(&[]model.File{*file}) - objectID = uint(0) +func (service *FileUpdateService) PutContent(c *gin.Context, ls fs.LockSession) (*FileResponse, error) { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + // 取得文件大小 + rc, fileSize, err := request.SniffContentLength(c.Request) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "invalid content length", err) } - // 如果上下文中已有Folder对象,则重设根目录 - if folder, ok := ctx.Value(fsctx.FolderModelCtx).(*model.Folder); ok { - fs.Root = folder - path := ctx.Value(fsctx.PathCtx).(string) - err := fs.ResetFileIfNotExist(ctx, path) - if err != nil { - return serializer.Err(serializer.CodeNotFound, err.Error(), err) - } - objectID = uint(0) + if fileSize > settings.MaxOnlineEditSize(c) { + return nil, fs.ErrFileSizeTooBig } - // 获取文件临时下载地址 - downloadURL, err := fs.GetDownloadURL(ctx, objectID.(uint), "doc_preview_timeout") + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - // For newer version of Cloudreve - Local Policy - // When do not use a cdn, the downloadURL withouts hosts, like "/api/v3/file/download/xxx" - if strings.HasPrefix(downloadURL, "/") { - downloadURI, err := url.Parse(downloadURL) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - downloadURL = model.GetSiteURL().ResolveReference(downloadURI).String() + fileData := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: uri, + PreviousVersion: service.Previous, + Size: fileSize, + }, + File: rc, + Mode: fs.ModeOverwrite, } - var resp serializer.DocPreviewSession + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - // Use WOPI preview if available - if model.IsTrueVal(model.GetSettingByName("wopi_enabled")) && wopi.Default != nil { - maxSize := model.GetIntSetting("maxEditSize", 0) - if maxSize > 0 && fs.FileTarget[0].Size > uint64(maxSize) { - return serializer.Err(serializer.CodeFileTooLarge, "", nil) - } - - action := wopi.ActionPreview - if editable { - action = wopi.ActionEdit - } + // Update file 
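// Illustrative sketch, not part of this patch: FileUpdateService.PutContent
// here overwrites a file from the raw request body, so a client sends the new
// content as the PUT body and passes the target uri (and, optionally, the
// previous entity version) as parameters matching the `uri`/`previous` binding
// tags above. The endpoint path and the uri value format below are
// assumptions; server-side, the sniffed content length is compared against
// MaxOnlineEditSize before the overwrite is dispatched.
package main

import (
	"log"
	"net/http"
	"net/url"
	"strings"
)

func main() {
	content := "updated file body"

	q := url.Values{}
	q.Set("uri", "cloudreve://my/notes.txt") // hypothetical URI value
	q.Set("previous", "")                    // optional previous version id

	// Placeholder route; the actual update endpoint is registered elsewhere.
	req, err := http.NewRequest(http.MethodPut,
		"http://localhost:5212/example-update-endpoint?"+q.Encode(),
		strings.NewReader(content))
	if err != nil {
		log.Fatal(err)
	}
	// NewRequest derives Content-Length from the strings.Reader; that is the
	// value the server-side size check is expected to see.
	req.Header.Set("Content-Type", "application/octet-stream")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println("status:", resp.Status)
}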
+ var ctx context.Context = c + if ls != nil { + ctx = fs.LockSessionToContext(c, ls) + } + res, err := m.Update(ctx, fileData) + if err != nil { + return nil, fmt.Errorf("failed to update file: %w", err) + } - session, err := wopi.Default.NewSession(fs.FileTarget[0].UserID, &fs.FileTarget[0], action) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to create WOPI session", err) - } + return BuildFileResponse(c, user, res, dep.HashIDEncoder(), nil), nil +} - resp.URL = session.ActionURL.String() - resp.AccessTokenTTL = session.AccessTokenTTL - resp.AccessToken = session.AccessToken - return serializer.Response{ - Code: 0, - Data: resp, - } +type ( + FileURLParameterCtx struct{} + FileURLService struct { + Uris []string `json:"uris" binding:"required"` + Download bool `json:"download"` + Redirect bool `json:"redirect"` // Only works if Uris count is 1. + Entity string `json:"entity"` // Only works if Uris count is 1. + UsePrimarySiteURL bool `json:"use_primary_site_url"` + SkipError bool `json:"skip_error"` + Archive bool `json:"archive"` + NoCache bool `json:"no_cache"` + } + FileURLResponse struct { + Urls []string `json:"urls"` + Expires *time.Time `json:"expires"` + } + ArchiveDownloadSession struct { + Uris []*fs.URI `json:"uris"` + RequesterID int `json:"requester_id"` } +) - // 生成最终的预览器地址 - srcB64 := base64.StdEncoding.EncodeToString([]byte(downloadURL)) - srcEncoded := url.QueryEscape(downloadURL) - srcB64Encoded := url.QueryEscape(srcB64) - resp.URL = util.Replace(map[string]string{ - "{$src}": srcEncoded, - "{$srcB64}": srcB64Encoded, - "{$name}": url.QueryEscape(fs.FileTarget[0].Name), - }, model.GetSettingByName("office_preview_service")) +const ( + ArchiveDownloadSessionPrefix = "archive_" +) - return serializer.Response{ - Code: 0, - Data: resp, - } +func (s *FileURLService) GetUris() []string { + return s.Uris } -// CreateDownloadSession 创建下载会话,获取下载URL -func (service *FileIDService) CreateDownloadSession(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) +// GetArchiveDownloadSession generates temporary download session for archive download. +func (s *FileURLService) GetArchiveDownloadSession(c *gin.Context) (*FileURLResponse, error) { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + user := inventory.UserFromContext(c) + + uris, err := fs.NewUriFromStrings(s.Uris...) 
if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - defer fs.Recycle() - - // 获取对象id - objectID, _ := c.Get("object_id") - // 获取下载地址 - downloadURL, err := fs.GetDownloadURL(ctx, objectID.(uint), "download_timeout") - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionArchiveDownload)) { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "", nil) } - return serializer.Response{ - Code: 0, - Data: downloadURL, + // Create archive download session + archiveSession := &ArchiveDownloadSession{ + Uris: uris, + RequesterID: user.ID, + } + sessionId := uuid.Must(uuid.NewV4()).String() + ttl := settings.ArchiveDownloadSessionTTL(c) + expire := time.Now().Add(time.Duration(ttl) * time.Second) + if err := dep.KV().Set(ArchiveDownloadSessionPrefix+sessionId, *archiveSession, ttl); err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "failed to create archive download session", err) } -} -// Download 通过签名URL的文件下载,无需登录 -func (service *DownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) + base := settings.SiteURL(c) + downloadUrl := routes.MasterArchiveDownloadUrl(base, sessionId) + finalUrl, err := auth.SignURI(c, dep.GeneralAuth(), downloadUrl.String(), &expire) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeInternalSetting, "failed to sign archive download url", err) } - defer fs.Recycle() - // 查找打包的临时文件 - file, exist := cache.Get("download_" + service.ID) - if !exist { - return serializer.Err(serializer.CodeNotFound, "Download session not exist", nil) + return &FileURLResponse{ + Urls: []string{finalUrl.String()}, + Expires: &expire, + }, nil +} + +func (s *FileURLService) Get(c *gin.Context) (*FileURLResponse, error) { + if s.Archive { + return s.GetArchiveDownloadSession(c) } - fs.FileTarget = []model.File{file.(model.File)} - // 开始处理下载 - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - rs, err := fs.GetDownloadContent(ctx, 0) + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uris, err := fs.NewUriFromStrings(s.Uris...) 
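// Illustrative sketch, not part of this patch: GetArchiveDownloadSession above
// keeps an ArchiveDownloadSession in the dependency KV store until
// DownloadArchived picks it up by session ID, and the init() in this file
// registers the type with encoding/gob. The sketch below shows the standard
// gob behaviour that registration enables when a KV backend serializes values
// stored behind an interface; the struct is a simplified stand-in for the real
// ArchiveDownloadSession, and the URI string is example data only.
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"log"
)

// Simplified stand-in for explorer.ArchiveDownloadSession.
type archiveDownloadSession struct {
	Uris        []string
	RequesterID int
}

func main() {
	// Without this registration, encoding the interface value below fails.
	gob.Register(archiveDownloadSession{})

	// A generic KV cache stores the session behind an empty interface.
	var stored interface{} = archiveDownloadSession{
		Uris:        []string{"cloudreve://my/report.pdf"},
		RequesterID: 1,
	}

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(&stored); err != nil {
		log.Fatal(err)
	}

	var loaded interface{}
	if err := gob.NewDecoder(&buf).Decode(&loaded); err != nil {
		log.Fatal(err)
	}

	// DownloadArchived performs the analogous type assertion on the value it
	// reads back from the KV store.
	session := loaded.(archiveDownloadSession)
	fmt.Printf("requester %d, %d uri(s)\n", session.RequesterID, len(session.Uris))
}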
if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - defer rs.Close() - // 设置文件名 - c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") + // Request entity URL + expire := time.Now().Add(settings.EntityUrlValidDuration(c)) + urlReq := lo.Map(uris, func(uri *fs.URI, _ int) manager.GetEntityUrlArgs { + return manager.GetEntityUrlArgs{ + URI: uri, + PreferredEntityID: s.Entity, + } + }) + + var ctx context.Context = c + if s.UsePrimarySiteURL { + ctx = setting.UseFirstSiteUrl(ctx) + } - if fs.User.Group.OptionsSerialized.OneTimeDownload { - // 清理资源,删除临时文件 - _ = cache.Deletes([]string{service.ID}, "download_") + res, earliestExpire, err := m.GetEntityUrls(ctx, urlReq, + fs.WithDownloadSpeed(int64(user.Edges.Group.SpeedLimit)), + fs.WithIsDownload(s.Download), + fs.WithNoCache(s.NoCache), + fs.WithUrlExpire(&expire), + ) + if err != nil && !s.SkipError { + return nil, fmt.Errorf("failed to get entity url: %w", err) } - // 发送文件 - http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, fs.FileTarget[0].UpdatedAt, rs) + //if !s.NoCache && earliestExpire != nil { + // // Set cache header + // cacheTTL := int(earliestExpire.Sub(time.Now()).Seconds() - float64(settings.EntityUrlCacheMargin(c))) + // if cacheTTL > 0 { + // c.Header("Cache-Control", fmt.Sprintf("private, max-age=%d", cacheTTL)) + // } + //} - return serializer.Response{ - Code: 0, + if s.Redirect && len(uris) == 1 { + c.Redirect(http.StatusFound, res[0]) + return nil, nil } + + return &FileURLResponse{ + Urls: res, + Expires: earliestExpire, + }, nil } -// PreviewContent 预览文件,需要登录会话, isText - 是否为文本文件,文本文件会 -// 强制经由服务端中转 -func (service *FileIDService) PreviewContent(ctx context.Context, c *gin.Context, isText bool) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) +type ( + FileThumbParameterCtx struct{} + FileThumbService struct { + Uri string `form:"uri" binding:"required"` } - defer fs.Recycle() + FileThumbResponse struct { + Url string `json:"url"` + Expires *time.Time `json:"expires"` + } +) - // 获取对象id - objectID, _ := c.Get("object_id") +// Get redirect to thumb file. 
+func (s *FileThumbService) Get(c *gin.Context) (*FileThumbResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - // 如果上下文中已有File对象,则重设目标 - if file, ok := ctx.Value(fsctx.FileModelCtx).(*model.File); ok { - fs.SetTargetFile(&[]model.File{*file}) - objectID = uint(0) + uri, err := fs.NewUriFromString(s.Uri) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - // 如果上下文中已有Folder对象,则重设根目录 - if folder, ok := ctx.Value(fsctx.FolderModelCtx).(*model.Folder); ok { - fs.Root = folder - path := ctx.Value(fsctx.PathCtx).(string) - err := fs.ResetFileIfNotExist(ctx, path) - if err != nil { - return serializer.Err(serializer.CodeFileNotFound, err.Error(), err) - } - objectID = uint(0) + // Get thumbnail + thumb, err := m.Thumbnail(c, uri) + if err != nil { + return nil, fmt.Errorf("failed to get thumbnail: %w", err) } - // 获取文件预览响应 - resp, err := fs.Preview(ctx, objectID.(uint), isText) + expire := time.Now().Add(dep.SettingProvider().EntityUrlValidDuration(c)) + thumbUrl, err := thumb.Url(c, entitysource.WithExpire(&expire)) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, fmt.Errorf("failed to get thumbnail url: %w", err) } - // 重定向到文件源 - if resp.Redirect { - c.Header("Cache-Control", fmt.Sprintf("max-age=%d", resp.MaxAge)) - return serializer.Response{ - Code: -301, - Data: resp.URL, - } + return &FileThumbResponse{ + Url: thumbUrl.Url, + Expires: thumbUrl.ExpireAt, + }, nil +} + +type ( + DeleteFileParameterCtx struct{} + DeleteFileService struct { + Uris []string `json:"uris" binding:"required,min=1"` + UnlinkOnly bool `json:"unlink"` + SkipSoftDelete bool `json:"skip_soft_delete"` } +) + +func (s *DeleteFileService) GetUris() []string { + return s.Uris +} - // 直接返回文件内容 - defer resp.Content.Close() +func (s *DeleteFileService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - if isText { - c.Header("Cache-Control", "no-cache") + uris, err := fs.NewUriFromStrings(s.Uris...) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, fs.FileTarget[0].UpdatedAt, resp.Content) + if s.UnlinkOnly && !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionAdvanceDelete)) { + return serializer.NewError(serializer.CodeNoPermissionErr, "advance delete permission is required", nil) + } - return serializer.Response{ - Code: 0, + // Delete file + if err = m.Delete(c, uris, fs.WithUnlinkOnly(s.UnlinkOnly), fs.WithSkipSoftDelete(s.SkipSoftDelete)); err != nil { + return fmt.Errorf("failed to delete file: %w", err) } + + return nil } -// PutContent 更新文件内容 -func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) serializer.Response { - // 创建上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func (s *DeleteFileService) Restore(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - // 取得文件大小 - fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64) + uris, err := fs.NewUriFromStrings(s.Uris...) 
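// Illustrative sketch, not part of this patch: the explorer services in this
// file are plain structs with binding tags plus a method taking *gin.Context,
// and the *ParameterCtx key types suggest the real router binds them into the
// request context via shared middleware. The handler below instead binds
// DeleteFileService directly with gin's ShouldBindJSON, purely to show the
// expected JSON shape ("uris", "unlink", "skip_soft_delete") and call flow;
// the route path is a placeholder, and the dependency/auth middleware that
// Delete relies on (dependency.FromContext, inventory.UserFromContext) is
// omitted, so this is a shape sketch rather than a working deployment.
package main

import (
	"net/http"

	"github.com/cloudreve/Cloudreve/v4/service/explorer"
	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()

	// Placeholder route; the real API paths are registered elsewhere.
	r.POST("/example/file/delete", func(c *gin.Context) {
		var svc explorer.DeleteFileService
		if err := c.ShouldBindJSON(&svc); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"msg": err.Error()})
			return
		}

		// Delete validates group permissions and dispatches to the file manager.
		if err := svc.Delete(c); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"msg": err.Error()})
			return
		}
		c.JSON(http.StatusOK, gin.H{"msg": "ok"})
	})

	r.Run(":5212")
}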
if err != nil { - - return serializer.ParamErr("Invalid content-length value", err) + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - fileData := fsctx.FileStream{ - MimeType: c.Request.Header.Get("Content-Type"), - File: c.Request.Body, - Size: fileSize, - Mode: fsctx.Overwrite, + // Delete file + if err = m.Restore(c, uris...); err != nil { + return fmt.Errorf("failed to restore file: %w", err) + } - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) - - // 取得现有文件 - fileID, _ := c.Get("object_id") - originFile, _ := model.GetFilesByIDs([]uint{fileID.(uint)}, fs.User.ID) - if len(originFile) == 0 { - return serializer.Err(serializer.CodeFileNotFound, "", nil) - } - fileData.Name = originFile[0].Name - - // 检查此文件是否有软链接 - fileList, err := model.RemoveFilesWithSoftLinks([]model.File{originFile[0]}) - if err == nil && len(fileList) == 0 { - // 如果包含软连接,应重新生成新文件副本,并更新source_name - originFile[0].SourceName = fs.GenerateSavePath(uploadCtx, &fileData) - fileData.Mode &= ^fsctx.Overwrite - fs.Use("AfterUpload", filesystem.HookUpdateSourceName) - fs.Use("AfterUploadCanceled", filesystem.HookUpdateSourceName) - fs.Use("AfterUploadCanceled", filesystem.HookCleanFileContent) - fs.Use("AfterUploadCanceled", filesystem.HookClearFileSize) - fs.Use("AfterValidateFailed", filesystem.HookUpdateSourceName) - fs.Use("AfterValidateFailed", filesystem.HookCleanFileContent) - fs.Use("AfterValidateFailed", filesystem.HookClearFileSize) - } - - // 给文件系统分配钩子 - fs.Use("BeforeUpload", filesystem.HookResetPolicy) - fs.Use("BeforeUpload", filesystem.HookValidateFile) - fs.Use("BeforeUpload", filesystem.HookValidateCapacityDiff) - fs.Use("AfterUpload", filesystem.GenericAfterUpdate) - - // 执行上传 - uploadCtx = context.WithValue(uploadCtx, fsctx.FileModelCtx, originFile[0]) - err = fs.Upload(uploadCtx, &fileData) - if err != nil { - return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) + return nil +} + +type ( + UnlockFileParameterCtx struct{} + UnlockFileService struct { + Tokens []string `json:"tokens" binding:"required,max=16384"` } +) + +func (s *UnlockFileService) Unlock(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - return serializer.Response{ - Code: 0, + // Unlock file + if err := m.Unlock(c, s.Tokens...); err != nil { + return serializer.NewError(serializer.CodeParamErr, "failed to unlock file", err) } + + return nil } -// Sources 批量获取对象的外链 -func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer.Response { - fs, err := filesystem.NewFileSystemFromContext(c) +type ( + GetFileInfoParameterCtx struct{} + GetFileInfoService struct { + Uri string `form:"uri" binding:"required"` + ExtendedInfo bool `form:"extended"` + FolderSummary bool `form:"folder_summary"` + } +) + +func (s *GetFileInfoService) Get(c *gin.Context) (*FileResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(s.Uri) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - defer fs.Recycle() - if len(s.Raw().Items) > fs.User.Group.OptionsSerialized.SourceBatchSize { - return 
serializer.Err(serializer.CodeBatchSourceSize, "", err) + opts := []fs.Option{dbfs.WithFilePublicMetadata()} + if s.ExtendedInfo { + opts = append(opts, dbfs.WithExtendedInfo(), dbfs.WithEntityUser(), dbfs.WithFileShareIfOwned()) + } + if s.FolderSummary { + opts = append(opts, dbfs.WithLoadFolderSummary()) } - res := make([]serializer.Sources, 0, len(s.Raw().Items)) - files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID) - if err != nil || len(files) == 0 { - return serializer.Err(serializer.CodeFileNotFound, "", err) + file, err := m.Get(c, uri, opts...) + if err != nil { + return nil, fmt.Errorf("failed to get file: %w", err) } - getSourceFunc := func(file model.File) (string, error) { - fs.FileTarget = []model.File{file} - return fs.GetSource(ctx, file.ID) + if file == nil { + return nil, serializer.NewError(serializer.CodeNotFound, "file not found", nil) } - // Create redirected source link if needed - if fs.User.Group.OptionsSerialized.RedirectedSource { - getSourceFunc = func(file model.File) (string, error) { - source, err := file.CreateOrGetSourceLink() - if err != nil { - return "", err - } + return BuildFileResponse(c, user, file, dep.HashIDEncoder(), nil), nil +} - sourceLinkURL, err := source.Link() - if err != nil { - return "", err - } +func RedirectDirectLink(c *gin.Context, name string) error { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() - return sourceLinkURL, nil - } + sourceLinkID := hashid.FromContext(c) + ctx := context.WithValue(c, inventory.LoadDirectLinkFile{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true) + ctx = context.WithValue(ctx, inventory.LoadFileUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true) + dl, err := dep.DirectLinkClient().GetByNameID(ctx, sourceLinkID, name) + if err != nil { + return serializer.NewError(serializer.CodeNotFound, "direct link not found", err) } - for _, file := range files { - sourceURL, err := getSourceFunc(file) - current := serializer.Sources{ - URL: sourceURL, - Name: file.Name, - Parent: file.FolderID, - } - - if err != nil { - current.Error = err.Error() - } + m := manager.NewFileManager(dep, dl.Edges.File.Edges.Owner) + defer m.Recycle() - res = append(res, current) + // Request entity URL + expire := time.Now().Add(settings.EntityUrlValidDuration(c)) + res, earliestExpire, err := m.GetUrlForRedirectedDirectLink(c, dl, + fs.WithUrlExpire(&expire), + ) + if err != nil { + return err } - return serializer.Response{ - Code: 0, - Data: res, - } + c.Redirect(http.StatusFound, res) + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", int(earliestExpire.Sub(time.Now()).Seconds()))) + return nil } diff --git a/service/explorer/metadata.go b/service/explorer/metadata.go new file mode 100644 index 00000000..e6ac3476 --- /dev/null +++ b/service/explorer/metadata.go @@ -0,0 +1,37 @@ +package explorer + +import ( + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" +) + +type ( + PatchMetadataService struct { + Uris []string `json:"uris" binding:"required"` + Patches []fs.MetadataPatch `json:"patches" binding:"required,dive"` + } + + PatchMetadataParameterCtx struct{} +) + +func (s *PatchMetadataService) GetUris() []string { + return s.Uris +} + +func (s *PatchMetadataService) Patch(c 
*gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uris, err := fs.NewUriFromStrings(s.Uris...) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + return m.PatchMedata(c, uris, s.Patches...) +} diff --git a/service/explorer/objects.go b/service/explorer/objects.go deleted file mode 100644 index 1c3c45a2..00000000 --- a/service/explorer/objects.go +++ /dev/null @@ -1,467 +0,0 @@ -package explorer - -import ( - "context" - "encoding/gob" - "fmt" - "math" - "path" - "strings" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" -) - -// ItemMoveService 处理多文件/目录移动 -type ItemMoveService struct { - SrcDir string `json:"src_dir" binding:"required,min=1,max=65535"` - Src ItemIDService `json:"src"` - Dst string `json:"dst" binding:"required,min=1,max=65535"` -} - -// ItemRenameService 处理多文件/目录重命名 -type ItemRenameService struct { - Src ItemIDService `json:"src"` - NewName string `json:"new_name" binding:"required,min=1,max=255"` -} - -// ItemService 处理多文件/目录相关服务 -type ItemService struct { - Items []uint `json:"items"` - Dirs []uint `json:"dirs"` -} - -// ItemIDService 处理多文件/目录相关服务,字段值为HashID,可通过Raw()方法获取原始ID -type ItemIDService struct { - Items []string `json:"items"` - Dirs []string `json:"dirs"` - Source *ItemService - Force bool `json:"force"` - UnlinkOnly bool `json:"unlink"` -} - -// ItemCompressService 文件压缩任务服务 -type ItemCompressService struct { - Src ItemIDService `json:"src"` - Dst string `json:"dst" binding:"required,min=1,max=65535"` - Name string `json:"name" binding:"required,min=1,max=255"` -} - -// ItemDecompressService 文件解压缩任务服务 -type ItemDecompressService struct { - Src string `json:"src"` - Dst string `json:"dst" binding:"required,min=1,max=65535"` - Encoding string `json:"encoding"` -} - -// ItemPropertyService 获取对象属性服务 -type ItemPropertyService struct { - ID string `binding:"required"` - TraceRoot bool `form:"trace_root"` - IsFolder bool `form:"is_folder"` -} - -func init() { - gob.Register(ItemIDService{}) -} - -// Raw 批量解码HashID,获取原始ID -func (service *ItemIDService) Raw() *ItemService { - if service.Source != nil { - return service.Source - } - - service.Source = &ItemService{ - Dirs: make([]uint, 0, len(service.Dirs)), - Items: make([]uint, 0, len(service.Items)), - } - for _, folder := range service.Dirs { - id, err := hashid.DecodeHashID(folder, hashid.FolderID) - if err == nil { - service.Source.Dirs = append(service.Source.Dirs, id) - } - } - for _, file := range service.Items { - id, err := hashid.DecodeHashID(file, hashid.FileID) - if err == nil { - service.Source.Items = append(service.Source.Items, id) - } - } - - return service.Source -} - -// CreateDecompressTask 创建文件解压缩任务 -func (service *ItemDecompressService) CreateDecompressTask(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.ArchiveTask { - return 
serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - // 存放目录是否存在 - if exist, _ := fs.IsPathExist(service.Dst); !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - - // 压缩包是否存在 - exist, file := fs.IsFileExist(service.Src) - if !exist { - return serializer.Err(serializer.CodeFileNotFound, "", nil) - } - - // 文件尺寸限制 - if fs.User.Group.OptionsSerialized.DecompressSize != 0 && file.Size > fs.User.Group. - OptionsSerialized.DecompressSize { - return serializer.Err(serializer.CodeFileTooLarge, "", nil) - } - - // 支持的压缩格式后缀 - var ( - suffixes = []string{".zip", ".gz", ".xz", ".tar", ".rar"} - matched bool - ) - for _, suffix := range suffixes { - if strings.HasSuffix(file.Name, suffix) { - matched = true - break - } - } - if !matched { - return serializer.Err(serializer.CodeUnsupportedArchiveType, "", nil) - } - - // 创建任务 - job, err := task.NewDecompressTask(fs.User, service.Src, service.Dst, service.Encoding) - if err != nil { - return serializer.Err(serializer.CodeCreateTaskError, "", err) - } - task.TaskPoll.Submit(job) - - return serializer.Response{} - -} - -// CreateCompressTask 创建文件压缩任务 -func (service *ItemCompressService) CreateCompressTask(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.ArchiveTask { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - // 补齐压缩文件扩展名(如果没有) - if !strings.HasSuffix(service.Name, ".zip") { - service.Name += ".zip" - } - - // 存放目录是否存在,是否重名 - if exist, _ := fs.IsPathExist(service.Dst); !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - if exist, _ := fs.IsFileExist(path.Join(service.Dst, service.Name)); exist { - return serializer.ParamErr("File "+service.Name+" already exist", nil) - } - - // 检查文件名合法性 - if !fs.ValidateLegalName(context.Background(), service.Name) { - return serializer.Err(serializer.CodeIllegalObjectName, "", nil) - } - if !fs.ValidateExtension(context.Background(), service.Name) { - return serializer.Err(serializer.CodeFileTypeNotAllowed, "", nil) - } - - // 递归列出待压缩子目录 - folders, err := model.GetRecursiveChildFolder(service.Src.Raw().Dirs, fs.User.ID, true) - if err != nil { - return serializer.DBErr("Failed to list folders", err) - } - - // 列出所有待压缩文件 - files, err := model.GetChildFilesOfFolders(&folders) - if err != nil { - return serializer.DBErr("Failed to list files", err) - } - - // 计算待压缩文件大小 - var totalSize uint64 - for i := 0; i < len(files); i++ { - totalSize += files[i].Size - } - - // 文件尺寸限制 - if fs.User.Group.OptionsSerialized.CompressSize != 0 && totalSize > fs.User.Group. 
- OptionsSerialized.CompressSize { - return serializer.Err(serializer.CodeFileTooLarge, "", nil) - } - - // 按照平均压缩率计算用户空间是否足够 - compressRatio := 0.4 - spaceNeeded := uint64(math.Round(float64(totalSize) * compressRatio)) - if fs.User.GetRemainingCapacity() < spaceNeeded { - return serializer.Err(serializer.CodeInsufficientCapacity, "", err) - } - - // 创建任务 - job, err := task.NewCompressTask(fs.User, path.Join(service.Dst, service.Name), service.Src.Raw().Dirs, - service.Src.Raw().Items) - if err != nil { - return serializer.Err(serializer.CodeCreateTaskError, "", err) - } - task.TaskPoll.Submit(job) - - return serializer.Response{} - -} - -// Archive 创建归档 -func (service *ItemIDService) Archive(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.ArchiveDownload { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - // 创建打包下载会话 - ttl := model.GetIntSetting("archive_timeout", 30) - downloadSessionID := util.RandStringRunes(16) - cache.Set("archive_"+downloadSessionID, *service, ttl) - cache.Set("archive_user_"+downloadSessionID, *fs.User, ttl) - signURL, err := auth.SignURI( - auth.General, - fmt.Sprintf("/api/v3/file/archive/%s/archive.zip", downloadSessionID), - int64(ttl), - ) - - return serializer.Response{ - Code: 0, - Data: signURL.String(), - } -} - -// Delete 删除对象 -func (service *ItemIDService) Delete(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) - } - defer fs.Recycle() - - force, unlink := false, false - if fs.User.Group.OptionsSerialized.AdvanceDelete { - force = service.Force - unlink = service.UnlinkOnly - } - - // 删除对象 - items := service.Raw() - err = fs.Delete(ctx, items.Dirs, items.Items, force, unlink) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - } - -} - -// Move 移动对象 -func (service *ItemMoveService) Move(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 移动对象 - items := service.Src.Raw() - err = fs.Move(ctx, items.Dirs, items.Items, service.SrcDir, service.Dst) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - } - -} - -// Copy 复制对象 -func (service *ItemMoveService) Copy(ctx context.Context, c *gin.Context) serializer.Response { - // 复制操作只能对一个目录或文件对象进行操作 - if len(service.Src.Items)+len(service.Src.Dirs) > 1 { - return filesystem.ErrOneObjectOnly - } - - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 复制对象 - err = fs.Copy(ctx, service.Src.Raw().Dirs, service.Src.Raw().Items, service.SrcDir, service.Dst) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - } - -} - -// Rename 重命名对象 -func (service *ItemRenameService) Rename(ctx context.Context, c *gin.Context) serializer.Response { - // 重命名作只能对一个目录或文件对象进行操作 - if 
len(service.Src.Items)+len(service.Src.Dirs) > 1 { - return filesystem.ErrOneObjectOnly - } - - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 重命名对象 - err = fs.Rename(ctx, service.Src.Raw().Dirs, service.Src.Raw().Items, service.NewName) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - } -} - -// GetProperty 获取对象的属性 -func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Context) serializer.Response { - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) - - var props serializer.ObjectProps - props.QueryDate = time.Now() - - // 如果是文件对象 - if !service.IsFolder { - res, err := hashid.DecodeHashID(service.ID, hashid.FileID) - if err != nil { - return serializer.Err(serializer.CodeNotFound, "", err) - } - - file, err := model.GetFilesByIDs([]uint{res}, user.ID) - if err != nil { - return serializer.DBErr("Failed to query file records", err) - } - - props.CreatedAt = file[0].CreatedAt - props.UpdatedAt = file[0].UpdatedAt - props.Policy = file[0].GetPolicy().Name - props.Size = file[0].Size - - // 查找父目录 - if service.TraceRoot { - parent, err := model.GetFoldersByIDs([]uint{file[0].FolderID}, user.ID) - if err != nil { - return serializer.DBErr("Parent folder record not exist", err) - } - - if err := parent[0].TraceRoot(); err != nil { - return serializer.DBErr("Failed to trace root folder", err) - } - - props.Path = path.Join(parent[0].Position, parent[0].Name) - } - } else { - res, err := hashid.DecodeHashID(service.ID, hashid.FolderID) - if err != nil { - return serializer.Err(serializer.CodeNotFound, "", err) - } - - folder, err := model.GetFoldersByIDs([]uint{res}, user.ID) - if err != nil { - return serializer.DBErr("Failed to query folder records", err) - } - - props.CreatedAt = folder[0].CreatedAt - props.UpdatedAt = folder[0].UpdatedAt - - // 如果对象是目录, 先尝试返回缓存结果 - if cacheRes, ok := cache.Get(fmt.Sprintf("folder_props_%d", res)); ok { - res := cacheRes.(serializer.ObjectProps) - res.CreatedAt = props.CreatedAt - res.UpdatedAt = props.UpdatedAt - return serializer.Response{Data: res} - } - - // 统计子目录 - childFolders, err := model.GetRecursiveChildFolder([]uint{folder[0].ID}, - user.ID, true) - if err != nil { - return serializer.DBErr("Failed to list child folders", err) - } - props.ChildFolderNum = len(childFolders) - 1 - - // 统计子文件 - files, err := model.GetChildFilesOfFolders(&childFolders) - if err != nil { - return serializer.DBErr("Failed to list child files", err) - } - - // 统计子文件个数和大小 - props.ChildFileNum = len(files) - for i := 0; i < len(files); i++ { - props.Size += files[i].Size - } - - // 查找父目录 - if service.TraceRoot { - if err := folder[0].TraceRoot(); err != nil { - return serializer.DBErr("Failed to list child folders", err) - } - - props.Path = folder[0].Position - } - - // 如果列取对象是目录,则缓存结果 - cache.Set(fmt.Sprintf("folder_props_%d", res), props, - model.GetIntSetting("folder_props_timeout", 300)) - } - - return serializer.Response{ - Code: 0, - Data: props, - } -} diff --git a/service/explorer/pin.go b/service/explorer/pin.go new file mode 100644 index 00000000..f7c2e20d --- /dev/null +++ b/service/explorer/pin.go @@ -0,0 +1,75 @@ +package explorer + +import ( + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + 
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" + "github.com/samber/lo" +) + +type ( + PinFileService struct { + Uri string `json:"uri" binding:"required"` + Name string `json:"name"` + } + PinFileParameterCtx struct{} +) + +// PinFileService pins new uri to sidebar +func (service *PinFileService) PinFile(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + userClient := dep.UserClient() + + uri, err := fs.NewUriFromString(service.Uri) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + uriStr := uri.String() + for _, pin := range user.Settings.Pined { + if pin.Uri == uriStr { + if pin.Name != service.Name { + return serializer.NewError(serializer.CodeObjectExist, "uri already pinned with different name", nil) + } + + return nil + } + } + + user.Settings.Pined = append(user.Settings.Pined, types.PinedFile{ + Uri: uriStr, + Name: service.Name, + }) + if err := userClient.SaveSettings(c, user); err != nil { + return serializer.NewError(serializer.CodeDBError, "failed to save settings", err) + } + + return nil +} + +// UnpinFile removes uri from sidebar +func (service *PinFileService) UnpinFile(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + userClient := dep.UserClient() + + uri, err := fs.NewUriFromString(service.Uri) + if err != nil { + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + uriStr := uri.String() + user.Settings.Pined = lo.Filter(user.Settings.Pined, func(pin types.PinedFile, index int) bool { + return pin.Uri != uriStr + }) + + if err := userClient.SaveSettings(c, user); err != nil { + return serializer.NewError(serializer.CodeDBError, "failed to save settings", err) + } + + return nil +} diff --git a/service/explorer/response.go b/service/explorer/response.go new file mode 100644 index 00000000..526b5b12 --- /dev/null +++ b/service/explorer/response.go @@ -0,0 +1,436 @@ +package explorer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/cloudreve/Cloudreve/v4/service/user" + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "github.com/samber/lo" +) + +type DirectLinkResponse struct { + Link string `json:"link"` + FileUrl string `json:"file_url"` +} + +func BuildDirectLinkResponse(links []manager.DirectLink) []DirectLinkResponse { + if len(links) == 0 { + return nil + } + + var res []DirectLinkResponse + for _, link := range links { + res = append(res, DirectLinkResponse{ + Link: link.Url, + FileUrl: link.File.Uri(false).String(), + }) + } + return res +} + +const PathMyRedacted = "redacted" + +type TaskResponse struct { + CreatedAt time.Time `json:"created_at,"` + UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Status string 
`json:"status"` + Type string `json:"type"` + Node *user.Node `json:"node,omitempty"` + Summary *queue.Summary `json:"summary,omitempty"` + Error string `json:"error,omitempty"` + ErrorHistory []string `json:"error_history,omitempty"` + Duration int64 `json:"duration,omitempty"` + ResumeTime int64 `json:"resume_time,omitempty"` + RetryCount int `json:"retry_count,omitempty"` +} + +type TaskListResponse struct { + Tasks []TaskResponse `json:"tasks"` + Pagination *inventory.PaginationResults `json:"pagination"` +} + +func BuildTaskListResponse(tasks []queue.Task, res *inventory.ListTaskResult, nodeMap map[int]*ent.Node, hasher hashid.Encoder) *TaskListResponse { + return &TaskListResponse{ + Pagination: res.PaginationResults, + Tasks: lo.Map(tasks, func(t queue.Task, index int) TaskResponse { + var ( + node *ent.Node + s = t.Summarize(hasher) + ) + + if s.NodeID > 0 { + node = nodeMap[s.NodeID] + } + return *BuildTaskResponse(t, node, hasher) + }), + } +} + +func BuildTaskResponse(task queue.Task, node *ent.Node, hasher hashid.Encoder) *TaskResponse { + model := task.Model() + t := &TaskResponse{ + Status: string(task.Status()), + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + ID: hashid.EncodeTaskID(hasher, task.ID()), + Type: task.Type(), + Summary: task.Summarize(hasher), + Error: auth.RedactSensitiveValues(model.PublicState.Error), + ErrorHistory: lo.Map(model.PublicState.ErrorHistory, func(s string, index int) string { + return auth.RedactSensitiveValues(s) + }), + Duration: model.PublicState.ExecutedDuration.Milliseconds(), + ResumeTime: model.PublicState.ResumeTime, + RetryCount: model.PublicState.RetryCount, + } + + if node != nil { + t.Node = user.BuildNode(node, hasher) + } + + return t +} + +type UploadSessionResponse struct { + SessionID string `json:"session_id"` + UploadID string `json:"upload_id"` + ChunkSize int64 `json:"chunk_size"` // Chunk size in bytes; 0 means the upload is not chunked + Expires int64 `json:"expires"` // Expiration time of the upload credential, as a Unix timestamp + UploadURLs []string `json:"upload_urls,omitempty"` + Credential string `json:"credential,omitempty"` + AccessKey string `json:"ak,omitempty"` + KeyTime string `json:"keyTime,omitempty"` // Validity period used by COS + CompleteURL string `json:"completeURL,omitempty"` + StoragePolicy *StoragePolicy `json:"storage_policy,omitempty"` + Uri string `json:"uri"` + CallbackSecret string `json:"callback_secret"` + MimeType string `json:"mime_type,omitempty"` + UploadPolicy string `json:"upload_policy,omitempty"` +} + +func BuildUploadSessionResponse(session *fs.UploadCredential, hasher hashid.Encoder) *UploadSessionResponse { + return &UploadSessionResponse{ + SessionID: session.SessionID, + ChunkSize: session.ChunkSize, + Expires: session.Expires, + UploadURLs: session.UploadURLs, + Credential: session.Credential, + CompleteURL: session.CompleteURL, + Uri: session.Uri, + UploadID: session.UploadID, + StoragePolicy: BuildStoragePolicy(session.StoragePolicy, hasher), + CallbackSecret: session.CallbackSecret, + MimeType: session.MimeType, + UploadPolicy: session.UploadPolicy, + } +} + +// WopiFileInfo Response for `CheckFileInfo` +type WopiFileInfo struct { + // Required + BaseFileName string + Version string + Size int64 + + // Breadcrumb + BreadcrumbBrandName string + BreadcrumbBrandUrl string + BreadcrumbFolderName string + BreadcrumbFolderUrl string + + // Post Message + FileSharingPostMessage bool + FileVersionPostMessage bool + ClosePostMessage bool + PostMessageOrigin string + + // Other miscellaneous properties + FileNameMaxLength int + LastModifiedTime string + + // User metadata +
IsAnonymousUser bool + UserFriendlyName string + UserId string + OwnerId string + + // Permission + ReadOnly bool + UserCanRename bool + UserCanReview bool + UserCanWrite bool + + SupportsRename bool + SupportsReviewing bool + SupportsUpdate bool + SupportsLocks bool + + EnableShare bool +} + +type ViewerSessionResponse struct { + Session *manager.ViewerSession `json:"session"` + WopiSrc string `json:"wopi_src,omitempty"` +} + +type ListResponse struct { + Files []FileResponse `json:"files"` + Parent FileResponse `json:"parent,omitempty"` + Pagination *inventory.PaginationResults `json:"pagination"` + Props *fs.NavigatorProps `json:"props"` + // ContextHint is used to speed up following operations under this listed directory. + // It persists some intermedia state so that the following request don't need to query database again. + // All the operations under this directory that supports context hint should carry this value in header + // as X-Cr-Context-Hint. + ContextHint *uuid.UUID `json:"context_hint"` + RecursionLimitReached bool `json:"recursion_limit_reached,omitempty"` + MixedType bool `json:"mixed_type"` + SingleFileView bool `json:"single_file_view,omitempty"` + StoragePolicy *StoragePolicy `json:"storage_policy,omitempty"` +} + +type FileResponse struct { + Type int `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Size int64 `json:"size"` + Metadata map[string]string `json:"metadata"` + Path string `json:"path,omitempty"` + Shared bool `json:"shared,omitempty"` + Capability *boolset.BooleanSet `json:"capability,omitempty"` + Owned bool `json:"owned,omitempty"` + PrimaryEntity string `json:"primary_entity,omitempty"` + + FolderSummary *fs.FolderSummary `json:"folder_summary,omitempty"` + ExtendedInfo *ExtendedInfo `json:"extended_info,omitempty"` +} + +type ExtendedInfo struct { + StoragePolicy *StoragePolicy `json:"storage_policy,omitempty"` + StorageUsed int64 `json:"storage_used"` + Shares []Share `json:"shares,omitempty"` + Entities []Entity `json:"entities,omitempty"` +} + +type StoragePolicy struct { + ID string `json:"id"` + Name string `json:"name"` + AllowedSuffix []string `json:"allowed_suffix,omitempty"` + Type types.PolicyType `json:"type"` + MaxSize int64 `json:"max_size"` + Relay bool `json:"relay,omitempty"` +} + +type Entity struct { + ID string `json:"id"` + Size int64 `json:"size"` + Type types.EntityType `json:"type"` + CreatedAt time.Time `json:"created_at"` + StoragePolicy *StoragePolicy `json:"storage_policy,omitempty"` + CreatedBy *user.User `json:"created_by,omitempty"` +} + +type Share struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + RemainDownloads *int `json:"remain_downloads,omitempty"` + Visited int `json:"visited"` + Downloaded int `json:"downloaded,omitempty"` + Expires *time.Time `json:"expires,omitempty"` + Unlocked bool `json:"unlocked"` + SourceType *types.FileType `json:"source_type,omitempty"` + Owner user.User `json:"owner"` + CreatedAt time.Time `json:"created_at,omitempty"` + Expired bool `json:"expired"` + Url string `json:"url"` + + // Only viewable by owner + IsPrivate bool `json:"is_private,omitempty"` + Password string `json:"password,omitempty"` + + // Only viewable if explicitly unlocked by owner + SourceUri string `json:"source_uri,omitempty"` +} + +func BuildShare(s *ent.Share, base *url.URL, hasher hashid.Encoder, requester *ent.User, owner *ent.User, + name string, t types.FileType, unlocked bool) *Share { 
+ redactLevel := user.RedactLevelAnonymous + if !inventory.IsAnonymousUser(requester) { + redactLevel = user.RedactLevelUser + } + res := Share{ + Name: name, + ID: hashid.EncodeShareID(hasher, s.ID), + Unlocked: unlocked, + Owner: user.BuildUserRedacted(owner, redactLevel, hasher), + Expired: inventory.IsShareExpired(s) != nil, + Url: BuildShareLink(s, hasher, base), + CreatedAt: s.CreatedAt, + Visited: s.Views, + SourceType: util.ToPtr(t), + } + + if unlocked { + res.RemainDownloads = s.RemainDownloads + res.Downloaded = s.Downloads + res.Expires = s.Expires + res.Password = s.Password + } + + if requester.ID == owner.ID { + res.IsPrivate = s.Password != "" + } + + return &res +} + +func BuildListResponse(ctx context.Context, u *ent.User, parent fs.File, res *fs.ListFileResult, hasher hashid.Encoder) *ListResponse { + r := &ListResponse{ + Files: lo.Map(res.Files, func(f fs.File, index int) FileResponse { + return *BuildFileResponse(ctx, u, f, hasher, res.Props.Capability) + }), + Pagination: res.Pagination, + Props: res.Props, + ContextHint: res.ContextHint, + RecursionLimitReached: res.RecursionLimitReached, + MixedType: res.MixedType, + SingleFileView: res.SingleFileView, + StoragePolicy: BuildStoragePolicy(res.StoragePolicy, hasher), + } + + if !res.Parent.IsNil() { + r.Parent = *BuildFileResponse(ctx, u, res.Parent, hasher, res.Props.Capability) + } + + return r +} + +func BuildFileResponse(ctx context.Context, u *ent.User, f fs.File, hasher hashid.Encoder, cap *boolset.BooleanSet) *FileResponse { + var owner *ent.User + if f != nil { + owner = f.Owner() + } + + if cap == nil { + cap = f.Capabilities() + } + + res := &FileResponse{ + Type: int(f.Type()), + ID: hashid.EncodeFileID(hasher, f.ID()), + Name: f.DisplayName(), + CreatedAt: f.CreatedAt(), + UpdatedAt: f.UpdatedAt(), + Size: f.Size(), + Metadata: f.Metadata(), + Path: f.Uri(false).String(), + Shared: f.Shared(), + Capability: cap, + Owned: owner == nil || owner.ID == u.ID, + FolderSummary: f.FolderSummary(), + ExtendedInfo: BuildExtendedInfo(ctx, u, f, hasher), + PrimaryEntity: hashid.EncodeEntityID(hasher, f.PrimaryEntityID()), + } + return res +} + +func BuildExtendedInfo(ctx context.Context, u *ent.User, f fs.File, hasher hashid.Encoder) *ExtendedInfo { + extendedInfo := f.ExtendedInfo() + if extendedInfo == nil { + return nil + } + + ext := &ExtendedInfo{ + StoragePolicy: BuildStoragePolicy(extendedInfo.StoragePolicy, hasher), + StorageUsed: extendedInfo.StorageUsed, + Entities: lo.Map(f.Entities(), func(e fs.Entity, index int) Entity { + return BuildEntity(extendedInfo, e, hasher) + }), + } + + dep := dependency.FromContext(ctx) + base := dep.SettingProvider().SiteURL(ctx) + if u.ID == f.OwnerID() { + // Only owner can see the shares settings. 
+ ext.Shares = lo.Map(extendedInfo.Shares, func(s *ent.Share, index int) Share { + return *BuildShare(s, base, hasher, u, u, f.DisplayName(), f.Type(), true) + }) + + } + + return ext +} + +func BuildEntity(extendedInfo *fs.FileExtendedInfo, e fs.Entity, hasher hashid.Encoder) Entity { + var u *user.User + createdBy := e.CreatedBy() + if createdBy != nil { + userRedacted := user.BuildUserRedacted(e.CreatedBy(), user.RedactLevelAnonymous, hasher) + u = &userRedacted + } + return Entity{ + ID: hashid.EncodeEntityID(hasher, e.ID()), + Type: e.Type(), + CreatedAt: e.CreatedAt(), + StoragePolicy: BuildStoragePolicy(extendedInfo.EntityStoragePolicies[e.PolicyID()], hasher), + Size: e.Size(), + CreatedBy: u, + } +} + +func BuildShareLink(s *ent.Share, hasher hashid.Encoder, base *url.URL) string { + shareId := hashid.EncodeShareID(hasher, s.ID) + return routes.MasterShareUrl(base, shareId, s.Password).String() +} + +func BuildStoragePolicy(sp *ent.StoragePolicy, hasher hashid.Encoder) *StoragePolicy { + if sp == nil { + return nil + } + return &StoragePolicy{ + ID: hashid.EncodePolicyID(hasher, sp.ID), + Name: sp.Name, + Type: types.PolicyType(sp.Type), + MaxSize: sp.MaxSize, + AllowedSuffix: sp.Settings.FileType, + Relay: sp.Settings.Relay, + } +} + +func WriteEventSourceHeader(c *gin.Context) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("X-Accel-Buffering", "no") +} + +// WriteEventSource writes a Server-Sent Event to the client. +func WriteEventSource(c *gin.Context, event string, data any) { + c.Writer.Write([]byte(fmt.Sprintf("event: %s\n", event))) + c.Writer.Write([]byte("data:")) + json.NewEncoder(c.Writer).Encode(data) + c.Writer.Write([]byte("\n")) + c.Writer.Flush() +} + +var ErrSSETakeOver = errors.New("SSE take over") diff --git a/service/explorer/search.go b/service/explorer/search.go deleted file mode 100644 index 72b7afaf..00000000 --- a/service/explorer/search.go +++ /dev/null @@ -1,88 +0,0 @@ -package explorer - -import ( - "context" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" -) - -// ItemSearchService 文件搜索服务 -type ItemSearchService struct { - Type string `uri:"type" binding:"required"` - Keywords string `uri:"keywords" binding:"required"` - Path string `form:"path"` -} - -// Search 执行搜索 -func (service *ItemSearchService) Search(c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - if service.Path != "" { - ok, parent := fs.IsPathExist(service.Path) - if !ok { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - - fs.Root = parent - } - - switch service.Type { - case "keywords": - return service.SearchKeywords(c, fs, "%"+service.Keywords+"%") - case "image": - return service.SearchKeywords(c, fs, "%.bmp", "%.iff", "%.png", "%.gif", "%.jpg", "%.jpeg", "%.psd", "%.svg", "%.webp") - case "video": - return service.SearchKeywords(c, fs, "%.mp4", "%.flv", "%.avi", "%.wmv", "%.mkv", "%.rm", "%.rmvb", "%.mov", "%.ogv") - case "audio": - return service.SearchKeywords(c, fs, "%.mp3", "%.flac", "%.ape", "%.wav", "%.acc", "%.ogg", "%.midi", "%.mid") - case "doc": - return service.SearchKeywords(c, fs, "%.txt", "%.md", "%.pdf", "%.doc", "%.docx", "%.ppt", 
"%.pptx", "%.xls", "%.xlsx", "%.pub") - case "tag": - if tid, err := hashid.DecodeHashID(service.Keywords, hashid.TagID); err == nil { - if tag, err := model.GetTagsByID(tid, fs.User.ID); err == nil { - if tag.Type == model.FileTagType { - exp := strings.Split(tag.Expression, "\n") - expInput := make([]interface{}, len(exp)) - for i := 0; i < len(exp); i++ { - expInput[i] = exp[i] - } - return service.SearchKeywords(c, fs, expInput...) - } - } - } - return serializer.Err(serializer.CodeNotFound, "", nil) - default: - return serializer.ParamErr("Unknown search type", nil) - } -} - -// SearchKeywords 根据关键字搜索文件 -func (service *ItemSearchService) SearchKeywords(c *gin.Context, fs *filesystem.FileSystem, keywords ...interface{}) serializer.Response { - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 获取子项目 - objects, err := fs.Search(ctx, keywords...) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - Data: map[string]interface{}{ - "parent": 0, - "objects": objects, - }, - } -} diff --git a/service/explorer/slave.go b/service/explorer/slave.go index eee840c6..f63e902a 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -1,25 +1,20 @@ package explorer import ( - "context" "encoding/base64" - "encoding/json" "fmt" - "net/http" - "net/url" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/local" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" + "github.com/samber/lo" + "strings" ) // SlaveDownloadService 从机文件下載服务 @@ -46,148 +41,211 @@ type SlaveListService struct { Recursive bool `json:"recursive"` } -// ServeFile 通过签名的URL下载从机文件 -func (service *SlaveDownloadService) ServeFile(ctx context.Context, c *gin.Context, isDownload bool) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() +// SlaveServe serves file content +func (s *EntityDownloadService) SlaveServe(c *gin.Context) error { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + + src, err := base64.URLEncoding.DecodeString(s.Src) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return fmt.Errorf("failed to decode src: %w", err) } - defer fs.Recycle() - // 解码文件路径 - fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) + entity, err := local.NewLocalFileEntity(types.EntityTypeVersion, string(src)) if err != nil { - return serializer.Err(serializer.CodeFileNotFound, "", err) + return fs.ErrPathNotExist.WithError(err) } - // 
根据URL里的信息创建一个文件对象和用户对象 - file := model.File{ - Name: service.Name, - SourceName: string(fileSource), - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, + entitySource, err := m.GetEntitySource(c, 0, fs.WithEntity(entity)) + if err != nil { + return fmt.Errorf("failed to get entity source: %w", err) } - fs.User = &model.User{ - Group: model.Group{SpeedLimit: service.Speed}, + + defer entitySource.Close() + + // Set cache header for public resource + settings := dep.SettingProvider() + maxAge := settings.PublicResourceMaxAge(c) + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", maxAge)) + + isDownload := c.Query(routes.IsDownloadQuery) != "" + entitySource.Serve(c.Writer, c.Request, + entitysource.WithSpeedLimit(s.SpeedLimit), + entitysource.WithDownload(isDownload), + entitysource.WithDisplayName(s.Name), + entitysource.WithContext(c), + ) + return nil +} + +type ( + SlaveCreateUploadSessionParamCtx struct{} + // SlaveCreateUploadSessionService 从机上传会话服务 + SlaveCreateUploadSessionService struct { + Session fs.UploadSession `json:"session" binding:"required"` + Overwrite bool `json:"overwrite"` } - fs.FileTarget = []model.File{file} +) - // 开始处理下载 - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - rs, err := fs.GetDownloadContent(ctx, 0) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) +// Create 从机创建上传会话 +func (service *SlaveCreateUploadSessionService) Create(c *gin.Context) error { + mode := fs.ModeNone + if service.Overwrite { + mode = fs.ModeOverwrite } - defer rs.Close() - // 设置下载文件名 - if isDownload { - c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") + req := &fs.UploadRequest{ + Mode: mode, + Props: service.Session.Props.Copy(), } - // 发送文件 - http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, time.Now(), rs) + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + _, err := m.CreateUploadSession(c, req, fs.WithUploadSession(&service.Session)) + if err != nil { + return serializer.NewError(serializer.CodeCacheOperation, "Failed to create upload session in slave node", err) + } - return serializer.Response{} + return nil } -// Delete 通过签名的URL删除从机文件 -func (service *SlaveFilesService) Delete(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) +type ( + SlaveMetaParamCtx struct{} + SlaveMetaService struct { + Src string `uri:"src" binding:"required"` + Ext string `uri:"ext" binding:"required"` } - defer fs.Recycle() +) - // 删除文件 - failed, err := fs.Handler.Delete(ctx, service.Files) +// MediaMeta retrieves media metadata +func (s *SlaveMetaService) MediaMeta(c *gin.Context) ([]driver.MediaMeta, error) { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + + src, err := base64.URLEncoding.DecodeString(s.Src) if err != nil { - // 将Data字段写为字符串方便主控端解析 - data, _ := json.Marshal(serializer.RemoteDeleteRequest{Files: failed}) - - return serializer.Response{ - Code: serializer.CodeNotFullySuccess, - Data: string(data), - Msg: fmt.Sprintf("Failed to delete %d files(s)", len(failed)), - Error: err.Error(), - } + return nil, fmt.Errorf("failed to decode src: %w", err) } - return serializer.Response{} -} -// Thumb 通过签名URL获取从机文件缩略图 -func (service *SlaveFileService) Thumb(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := 
filesystem.NewAnonymousFileSystem() + entity, err := local.NewLocalFileEntity(types.EntityTypeVersion, string(src)) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, fs.ErrPathNotExist.WithError(err) } - defer fs.Recycle() - // 解码文件路径 - fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded) + entitySource, err := m.GetEntitySource(c, 0, fs.WithEntity(entity)) if err != nil { - return serializer.Err(serializer.CodeFileNotFound, "", err) + return nil, fmt.Errorf("failed to get entity source: %w", err) } - fs.FileTarget = []model.File{{SourceName: string(fileSource), Name: fmt.Sprintf("%s.%s", fileSource, service.Ext), PicInfo: "1,1"}} + defer entitySource.Close() - // 获取缩略图 - resp, err := fs.GetThumb(ctx, 0) + extractor := dep.MediaMetaExtractor(c) + res, err := extractor.Extract(c, s.Ext, entitySource) if err != nil { - return serializer.Err(serializer.CodeNotSet, "Failed to get thumb", err) + return nil, fmt.Errorf("failed to extract media meta: %w", err) } - defer resp.Content.Close() - http.ServeContent(c.Writer, c.Request, "thumb.png", time.Now(), resp.Content) - - return serializer.Response{} + return res, nil } -// CreateTransferTask 创建从机文件转存任务 -func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serializer.Response { - if id, ok := c.Get("MasterSiteID"); ok { - job := &slavetask.TransferTask{ - Req: req, - MasterID: id.(string), +type ( + SlaveThumbParamCtx struct{} + SlaveThumbService struct { + Src string `uri:"src" binding:"required"` + Ext string `uri:"ext" binding:"required"` + } +) + +func (s *SlaveThumbService) Thumb(c *gin.Context) error { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + + src, err := base64.URLEncoding.DecodeString(s.Src) + if err != nil { + return fmt.Errorf("failed to decode src: %w", err) + } + + settings := dep.SettingProvider() + var entity fs.Entity + entity, err = local.NewLocalFileEntity(types.EntityTypeThumbnail, string(src)+settings.ThumbSlaveSidecarSuffix(c)) + if err != nil { + srcEntity, err := local.NewLocalFileEntity(types.EntityTypeVersion, string(src)) + if err != nil { + return fs.ErrPathNotExist.WithError(err) } - if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { - task.TaskPoll.Submit(job.(task.Job)) - }); err != nil { - return serializer.Err(serializer.CodeCreateTaskError, "", err) + entity, err = m.SubmitAndAwaitThumbnailTask(c, nil, s.Ext, srcEntity) + if err != nil { + return fmt.Errorf("failed to submit and await thumbnail task: %w", err) } + } - return serializer.Response{} + entitySource, err := m.GetEntitySource(c, 0, fs.WithEntity(entity)) + if err != nil { + return fmt.Errorf("failed to get thumb entity source: %w", err) } - return serializer.ParamErr("未知的主机节点ID", nil) + defer entitySource.Close() + + // Set cache header for public resource + maxAge := settings.PublicResourceMaxAge(c) + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", maxAge)) + + entitySource.Serve(c.Writer, c.Request, + entitysource.WithContext(c), + ) + return nil } -// SlaveListService 从机上传会话服务 -type SlaveCreateUploadSessionService struct { - Session serializer.UploadSession `json:"session" binding:"required"` - TTL int64 `json:"ttl"` - Overwrite bool `json:"overwrite"` +type ( + SlaveDeleteUploadSessionParamCtx struct{} + SlaveDeleteUploadSessionService struct { + ID string `uri:"sessionId" binding:"required"` + } +) + +// Delete deletes an upload 
session from slave node +func (service *SlaveDeleteUploadSessionService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + + err := m.CancelUploadSession(c, nil, service.ID) + if err != nil { + return fmt.Errorf("slave failed to delete upload session: %w", err) + } + + return nil } -// Create 从机创建上传会话 -func (service *SlaveCreateUploadSessionService) Create(ctx context.Context, c *gin.Context) serializer.Response { - if !service.Overwrite && util.Exists(service.Session.SavePath) { - return serializer.Err(serializer.CodeConflict, "placeholder file already exist", nil) +type ( + SlaveDeleteFileParamCtx struct{} + SlaveDeleteFileService struct { + Files []string `json:"files" binding:"required,gt=0"` } +) - err := cache.Set( - filesystem.UploadSessionCachePrefix+service.Session.Key, - service.Session, - int(service.TTL), - ) +func (service *SlaveDeleteFileService) Delete(c *gin.Context) ([]string, error) { + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + d := m.LocalDriver(nil) + + // Try to delete thumbnail sidecar + sidecarSuffix := dep.SettingProvider().ThumbSlaveSidecarSuffix(c) + failed, err := d.Delete(c, lo.Map(service.Files, func(item string, index int) string { + return item + sidecarSuffix + })...) + if err != nil { + dep.Logger().Warning("Failed to delete thumbnail sidecar [%s]: %s", strings.Join(failed, ", "), err) + } + + failed, err = d.Delete(c, service.Files...) if err != nil { - return serializer.Err(serializer.CodeCacheOperation, "Failed to create upload session in slave node", err) + return failed, fmt.Errorf("slave failed to delete file: %w", err) } - return serializer.Response{} + return nil, nil } diff --git a/service/explorer/tag.go b/service/explorer/tag.go deleted file mode 100644 index 02e324f1..00000000 --- a/service/explorer/tag.go +++ /dev/null @@ -1,88 +0,0 @@ -package explorer - -import ( - "fmt" - "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" -) - -// FilterTagCreateService 文件分类标签创建服务 -type FilterTagCreateService struct { - Expression string `json:"expression" binding:"required,min=1,max=65535"` - Icon string `json:"icon" binding:"required,min=1,max=255"` - Name string `json:"name" binding:"required,min=1,max=255"` - Color string `json:"color" binding:"hexcolor|rgb|rgba|hsl"` -} - -// LinkTagCreateService 目录快捷方式标签创建服务 -type LinkTagCreateService struct { - Path string `json:"path" binding:"required,min=1,max=65535"` - Name string `json:"name" binding:"required,min=1,max=255"` -} - -// TagService 标签服务 -type TagService struct { -} - -// Delete 删除标签 -func (service *TagService) Delete(c *gin.Context, user *model.User) serializer.Response { - id, _ := c.Get("object_id") - if err := model.DeleteTagByID(id.(uint), user.ID); err != nil { - return serializer.DBErr("Failed to delete a tag", err) - } - return serializer.Response{} -} - -// Create 创建标签 -func (service *LinkTagCreateService) Create(c *gin.Context, user *model.User) serializer.Response { - // 创建标签 - tag := model.Tag{ - Name: service.Name, - Icon: "FolderHeartOutline", - Type: model.DirectoryLinkType, - Expression: service.Path, - UserID: user.ID, - } - id, err := tag.Create() - if err != nil { - return serializer.DBErr("Failed to create a tag", err) - } - - return serializer.Response{ - Data: hashid.HashID(id, hashid.TagID), - } -} - -// 
Create 创建标签 -func (service *FilterTagCreateService) Create(c *gin.Context, user *model.User) serializer.Response { - // 分割表达式,将通配符转换为SQL内的% - expressions := strings.Split(service.Expression, "\n") - for i := 0; i < len(expressions); i++ { - expressions[i] = strings.ReplaceAll(expressions[i], "*", "%") - if expressions[i] == "" { - return serializer.ParamErr(fmt.Sprintf("The %d line contains an empty match expression", i+1), nil) - } - } - - // 创建标签 - tag := model.Tag{ - Name: service.Name, - Icon: service.Icon, - Color: service.Color, - Type: model.FileTagType, - Expression: strings.Join(expressions, "\n"), - UserID: user.ID, - } - id, err := tag.Create() - if err != nil { - return serializer.DBErr("Failed to create a tag", err) - } - - return serializer.Response{ - Data: hashid.HashID(id, hashid.TagID), - } -} diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 0c26c26c..54b2a8b5 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -3,167 +3,158 @@ package explorer import ( "context" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" "github.com/gin-gonic/gin" - "io/ioutil" "strconv" - "strings" "time" ) // CreateUploadSessionService 获取上传凭证服务 -type CreateUploadSessionService struct { - Path string `json:"path" binding:"required"` - Size uint64 `json:"size" binding:"min=0"` - Name string `json:"name" binding:"required"` - PolicyID string `json:"policy_id" binding:"required"` - LastModified int64 `json:"last_modified"` - MimeType string `json:"mime_type"` -} +type ( + CreateUploadSessionParameterCtx struct{} + CreateUploadSessionService struct { + Uri string `json:"uri" binding:"required"` + Size int64 `json:"size" binding:"min=0"` + LastModified int64 `json:"last_modified"` + MimeType string `json:"mime_type"` + PolicyID string `json:"policy_id"` + Metadata map[string]string `json:"metadata" binding:"max=256"` + EntityType string `json:"entity_type" binding:"eq=|eq=live_photo|eq=version"` + } +) // Create 创建新的上传会话 -func (service *CreateUploadSessionService) Create(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) +func (service *CreateUploadSessionService) Create(c context.Context) (*UploadSessionResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - // 取得存储策略的ID - rawID, 
err := hashid.DecodeHashID(service.PolicyID, hashid.PolicyID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) + var entityType *types.EntityType + switch service.EntityType { + case "live_photo": + livePhoto := types.EntityTypeLivePhoto + entityType = &livePhoto + case "version": + version := types.EntityTypeVersion + entityType = &version } - if fs.Policy.ID != rawID { - return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略发生变化,请刷新文件列表并重新添加此任务", nil) + hasher := dep.HashIDEncoder() + policyId, err := hasher.Decode(service.PolicyID, hashid.PolicyID) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "unknown policy id", err) } - file := &fsctx.FileStream{ - Size: service.Size, - Name: service.Name, - VirtualPath: service.Path, - File: ioutil.NopCloser(strings.NewReader("")), - MimeType: service.MimeType, + uploadRequest := &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: uri, + Size: service.Size, + + MimeType: service.MimeType, + Metadata: service.Metadata, + EntityType: entityType, + PreferredStoragePolicy: policyId, + }, } + if service.LastModified > 0 { lastModified := time.UnixMilli(service.LastModified) - file.LastModified = &lastModified + uploadRequest.Props.LastModified = &lastModified } - credential, err := fs.CreateUploadSession(ctx, file) + + credential, err := m.CreateUploadSession(c, uploadRequest) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, err } - return serializer.Response{ - Code: 0, - Data: credential, - } + return BuildUploadSessionResponse(credential, hasher), nil } -// UploadService 本机及从机策略上传服务 -type UploadService struct { - ID string `uri:"sessionId" binding:"required"` - Index int `uri:"index" form:"index" binding:"min=0"` -} +type ( + UploadParameterCtx struct{} + // UploadService 本机及从机策略上传服务 + UploadService struct { + ID string `uri:"sessionId" binding:"required"` + Index int `uri:"index" form:"index" binding:"min=0"` + } +) // LocalUpload 处理本机文件分片上传 -func (service *UploadService) LocalUpload(ctx context.Context, c *gin.Context) serializer.Response { - uploadSessionRaw, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) +func (service *UploadService) LocalUpload(c *gin.Context) error { + dep := dependency.FromContext(c) + kv := dep.KV() + + uploadSessionRaw, ok := kv.Get(manager.UploadSessionCachePrefix + service.ID) if !ok { - return serializer.Err(serializer.CodeUploadSessionExpired, "", nil) + return serializer.NewError(serializer.CodeUploadSessionExpired, "", nil) } - uploadSession := uploadSessionRaw.(serializer.UploadSession) + uploadSession := uploadSessionRaw.(fs.UploadSession) - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) - } + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() - if uploadSession.UID != fs.User.ID { - return serializer.Err(serializer.CodeUploadSessionExpired, "", nil) + if uploadSession.UID != user.ID { + return serializer.NewError(serializer.CodeUploadSessionExpired, "", nil) } - // 查找上传会话创建的占位文件 - file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) + // Confirm upload session and chunk index + placeholder, err := m.ConfirmUploadSession(c, &uploadSession, service.Index) if err != nil { - return serializer.Err(serializer.CodeUploadSessionExpired, "", err) - } - - // 重设 fs 存储策略 - if !uploadSession.Policy.IsTransitUpload(uploadSession.Size) { - 
return serializer.Err(serializer.CodePolicyNotAllowed, "", err) - } - - fs.Policy = &uploadSession.Policy - if err := fs.DispatchHandler(); err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) - } - - expectedSizeStart := file.Size - actualSizeStart := uint64(service.Index) * uploadSession.Policy.OptionsSerialized.ChunkSize - if uploadSession.Policy.OptionsSerialized.ChunkSize == 0 && service.Index > 0 { - return serializer.Err(serializer.CodeInvalidChunkIndex, "Chunk index cannot be greater than 0", nil) + return err } - if expectedSizeStart < actualSizeStart { - return serializer.Err(serializer.CodeInvalidChunkIndex, "Chunk must be uploaded in order", nil) - } - - if expectedSizeStart > actualSizeStart { - util.Log().Info("Trying to overwrite chunk[%d] Start=%d", service.Index, actualSizeStart) - } - - return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, file, fsctx.Append) + return processChunkUpload(c, m, &uploadSession, service.Index, placeholder, fs.ModeOverwrite) } // SlaveUpload 处理从机文件分片上传 -func (service *UploadService) SlaveUpload(ctx context.Context, c *gin.Context) serializer.Response { - uploadSessionRaw, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) - if !ok { - return serializer.Err(serializer.CodeUploadSessionExpired, "", nil) - } - - uploadSession := uploadSessionRaw.(serializer.UploadSession) +func (service *UploadService) SlaveUpload(c *gin.Context) error { + dep := dependency.FromContext(c) + kv := dep.KV() - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + uploadSessionRaw, ok := kv.Get(manager.UploadSessionCachePrefix + service.ID) + if !ok { + return serializer.NewError(serializer.CodeUploadSessionExpired, "", nil) } - fs.Handler = local.Driver{} + uploadSession := uploadSessionRaw.(fs.UploadSession) - // 解析需要的参数 + // Parse chunk index from query service.Index, _ = strconv.Atoi(c.Query("chunk")) - mode := fsctx.Append - if c.GetHeader(auth.CrHeaderPrefix+"Overwrite") == "true" { - mode |= fsctx.Overwrite - } - return processChunkUpload(ctx, c, fs, &uploadSession, service.Index, nil, mode) + m := manager.NewFileManager(dep, nil) + defer m.Recycle() + + return processChunkUpload(c, m, &uploadSession, service.Index, nil, fs.ModeOverwrite) } -func processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.FileSystem, session *serializer.UploadSession, index int, file *model.File, mode fsctx.WriteMode) serializer.Response { +func processChunkUpload(c *gin.Context, m manager.FileManager, session *fs.UploadSession, index int, file fs.File, mode fs.WriteMode) error { // 取得并校验文件大小是否符合分片要求 - chunkSize := session.Policy.OptionsSerialized.ChunkSize - isLastChunk := session.Policy.OptionsSerialized.ChunkSize == 0 || uint64(index+1)*chunkSize >= session.Size + chunkSize := session.ChunkSize + isLastChunk := session.ChunkSize == 0 || int64(index+1)*chunkSize >= session.Props.Size expectedLength := chunkSize if isLastChunk { - expectedLength = session.Size - uint64(index)*chunkSize + expectedLength = session.Props.Size - int64(index)*chunkSize } - fileSize, err := strconv.ParseUint(c.Request.Header.Get("Content-Length"), 10, 64) + rc, fileSize, err := request.SniffContentLength(c.Request) if err != nil || (expectedLength != fileSize) { - return serializer.Err( + return serializer.NewError( serializer.CodeInvalidContentLength, fmt.Sprintf("Invalid Content-Length (expected: %d)", expectedLength), err, @@ -172,121 +163,60 @@ func 
processChunkUpload(ctx context.Context, c *gin.Context, fs *filesystem.File // 非首个分片时需要允许覆盖 if index > 0 { - mode |= fsctx.Overwrite - } - - fileData := fsctx.FileStream{ - MimeType: c.Request.Header.Get("Content-Type"), - File: c.Request.Body, - Size: fileSize, - Name: session.Name, - VirtualPath: session.VirtualPath, - SavePath: session.SavePath, - Mode: mode, - AppendStart: chunkSize * uint64(index), - Model: file, - LastModified: session.LastModified, + mode |= fs.ModeOverwrite } - // 给文件系统分配钩子 - fs.Use("AfterUploadCanceled", filesystem.HookTruncateFileTo(fileData.AppendStart)) - fs.Use("AfterValidateFailed", filesystem.HookTruncateFileTo(fileData.AppendStart)) - - if file != nil { - fs.Use("BeforeUpload", filesystem.HookValidateCapacity) - fs.Use("AfterUpload", filesystem.HookChunkUploaded) - fs.Use("AfterValidateFailed", filesystem.HookChunkUploadFailed) - if isLastChunk { - fs.Use("AfterUpload", filesystem.HookPopPlaceholderToFile("")) - fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) - } - } else { - if isLastChunk { - fs.Use("AfterUpload", filesystem.SlaveAfterUpload(session)) - fs.Use("AfterUpload", filesystem.HookDeleteUploadSession(session.Key)) - } + req := &fs.UploadRequest{ + File: rc, + Offset: chunkSize * int64(index), + Props: session.Props.Copy(), + Mode: mode, } // 执行上传 - uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) - err = fs.Upload(uploadCtx, &fileData) + ctx := context.WithValue(c, cluster.SlaveNodeIDCtx{}, strconv.Itoa(session.Policy.NodeID)) + err = m.Upload(ctx, req, session.Policy) if err != nil { - return serializer.Err(serializer.CodeUploadFailed, err.Error(), err) + return err } - return serializer.Response{} -} - -// UploadSessionService 上传会话服务 -type UploadSessionService struct { - ID string `uri:"sessionId" binding:"required"` -} - -// Delete 删除指定上传会话 -func (service *UploadSessionService) Delete(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 查找需要删除的上传会话的占位文件 - file, err := model.GetFilesByUploadSession(service.ID, fs.User.ID) - if err != nil { - return serializer.Err(serializer.CodeUploadSessionExpired, "", err) + if rc, ok := req.File.(request.LimitReaderCloser); ok { + if rc.Count() != expectedLength { + err := fmt.Errorf("uploaded data(%d) does not match purposed size(%d)", rc.Count(), req.Props.Size) + return serializer.NewError(serializer.CodeIOFailed, "Uploaded data does not match purposed size", err) + } } - // 删除文件 - if err := fs.Delete(ctx, []uint{}, []uint{file.ID}, false, false); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to delete upload session", err) + // Finish upload + if isLastChunk { + _, err := m.CompleteUpload(ctx, session) + if err != nil { + return fmt.Errorf("failed to complete upload: %w", err) + } } - return serializer.Response{} + return nil } -// SlaveDelete 从机删除指定上传会话 -func (service *UploadSessionService) SlaveDelete(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewAnonymousFileSystem() - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - session, ok := cache.Get(filesystem.UploadSessionCachePrefix + service.ID) - if !ok { - return serializer.Err(serializer.CodeUploadSessionExpired, "", nil) - } - - if _, err := fs.Handler.Delete(ctx, 
[]string{session.(serializer.UploadSession).SavePath}); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to delete temp file", err) +type ( + DeleteUploadSessionParameterCtx struct{} + DeleteUploadSessionService struct { + ID string `json:"id" binding:"required"` + Uri string `json:"uri" binding:"required"` } +) - cache.Deletes([]string{service.ID}, filesystem.UploadSessionCachePrefix) - return serializer.Response{} -} +// Delete deletes the specified upload session +func (service *DeleteUploadSessionService) Delete(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() -// DeleteAllUploadSession 删除当前用户的全部上传绘会话 -func DeleteAllUploadSession(ctx context.Context, c *gin.Context) serializer.Response { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) + uri, err := fs.NewUriFromString(service.Uri) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 查找需要删除的上传会话的占位文件 - files := model.GetUploadPlaceholderFiles(fs.User.ID) - fileIDs := make([]uint, len(files)) - for i, file := range files { - fileIDs[i] = file.ID - } - - // 删除文件 - if err := fs.Delete(ctx, []uint{}, fileIDs, false, false); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to cleanup upload session", err) + return serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - return serializer.Response{} + return m.CancelUploadSession(c, uri, service.ID) } diff --git a/service/explorer/viewer.go b/service/explorer/viewer.go new file mode 100644 index 00000000..f7ede909 --- /dev/null +++ b/service/explorer/viewer.go @@ -0,0 +1,394 @@ +package explorer + +import ( + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager/entitysource" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/wopi" + "github.com/gin-gonic/gin" + "net/http" + "time" +) + +type WopiService struct { +} + +func prepareFs(c *gin.Context) (*fs.URI, manager.FileManager, *ent.User, *manager.ViewerSessionCache, dependency.Dep, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + viewerSession := manager.ViewerSessionFromContext(c) + uri, err := fs.NewUriFromString(viewerSession.Uri) + if err != nil { + return nil, nil, nil, nil, nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + return uri, m, user, viewerSession, dep, nil +} + +func (service *WopiService) Unlock(c *gin.Context) error { + _, m, _, _, dep, err := prepareFs(c) + if err != nil { + return err + } + + l := dep.Logger() + + lockToken := c.GetHeader(wopi.LockTokenHeader) + if err = m.Unlock(c, lockToken); err != nil { + l.Debug("WOPI unlock, not locked or not match: %w", err) + 
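+ // Lock token mismatch or file not locked: answer 409 Conflict with an empty lock token header instead of surfacing an error.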
c.Status(http.StatusConflict) + c.Header(wopi.LockTokenHeader, "") + return nil + } + + return nil +} + +func (service *WopiService) RefreshLock(c *gin.Context) error { + uri, m, _, _, dep, err := prepareFs(c) + if err != nil { + return err + } + + l := dep.Logger() + + // Make sure file exists and readable + file, err := m.Get(c, uri, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityLockFile)) + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + lockToken := c.GetHeader(wopi.LockTokenHeader) + release, _, err := m.ConfirmLock(c, file, file.Uri(false), lockToken) + if err != nil { + // File not locked for token not match + + l.Debug("WOPI refresh lock, not locked or not match: %w", err) + c.Status(http.StatusConflict) + c.Header(wopi.LockTokenHeader, "") + return nil + + } + + // refresh lock + release() + _, err = m.Refresh(c, wopi.LockDuration, lockToken) + if err != nil { + return err + } + + c.Header(wopi.LockTokenHeader, lockToken) + return nil +} + +func (service *WopiService) Lock(c *gin.Context) error { + uri, m, user, viewerSession, dep, err := prepareFs(c) + if err != nil { + return err + } + + l := dep.Logger() + + // Make sure file exists and readable + file, err := m.Get(c, uri, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityLockFile)) + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + lockToken := c.GetHeader(wopi.LockTokenHeader) + release, _, err := m.ConfirmLock(c, file, file.Uri(false), lockToken) + if err != nil { + // File not locked for token not match + + // Try to lock using given token + app := lock.Application{ + Type: string(fs.ApplicationViewer), + ViewerID: viewerSession.ViewerID, + } + _, err = m.Lock(c, wopi.LockDuration, user, true, app, file.Uri(false), lockToken) + if err != nil { + // Token not match + var lockConflict lock.ConflictError + if errors.As(err, &lockConflict) { + c.Status(http.StatusConflict) + c.Header(wopi.LockTokenHeader, lockConflict[0].Token) + + l.Debug("WOPI lock, lock conflict: %w", err) + return nil + } + + return fmt.Errorf("failed to lock file: %w", err) + } + + // Lock success, return the token + c.Header(wopi.LockTokenHeader, lockToken) + return nil + + } + + // refresh lock + release() + _, err = m.Refresh(c, wopi.LockDuration, lockToken) + if err != nil { + return err + } + + c.Header(wopi.LockTokenHeader, lockToken) + return nil +} + +func (service *WopiService) PutContent(c *gin.Context) error { + uri, m, user, viewerSession, _, err := prepareFs(c) + if err != nil { + return err + } + + // Make sure file exists and readable + file, err := m.Get(c, uri, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityUploadFile)) + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + var lockSession fs.LockSession + lockToken := c.GetHeader(wopi.LockTokenHeader) + if lockToken != "" { + // File not locked for token not match + + release, ls, err := m.ConfirmLock(c, file, file.Uri(false), lockToken) + if err != nil { + // File not locked for token not match + + // Try to lock using given token + app := lock.Application{ + Type: string(fs.ApplicationViewer), + ViewerID: viewerSession.ViewerID, + } + ls, err := m.Lock(c, wopi.LockDuration, user, true, app, file.Uri(false), lockToken) + if err != nil { + // Token not match + // If the file is currently locked and the X-WOPI-Lock value doesn't match the lock currently on the file, the host must + // + // Return a lock mismatch response (409 Conflict) + // Include an X-WOPI-Lock response header containing the 
value of the current lock on the file. + var lockConflict lock.ConflictError + if errors.As(err, &lockConflict) { + c.Status(http.StatusConflict) + c.Header(wopi.LockTokenHeader, lockConflict[0].Token) + + return nil + } + + return fmt.Errorf("failed to lock file: %w", err) + } + + // In cases where the file is unlocked, the host must set X-WOPI-Lock to the empty string. + c.Header(wopi.LockTokenHeader, "") + _ = m.Unlock(c, ls.LastToken()) + } else { + defer release() + } + + lockSession = ls + } + + subService := FileUpdateService{ + Uri: viewerSession.Uri, + } + + res, err := subService.PutContent(c, lockSession) + if err != nil { + var appErr serializer.AppError + if errors.As(err, &appErr) { + switch appErr.Code { + case serializer.CodeFileTooLarge: + c.Status(http.StatusRequestEntityTooLarge) + c.Header(wopi.ServerErrorHeader, err.Error()) + case serializer.CodeNotFound: + c.Status(http.StatusNotFound) + c.Header(wopi.ServerErrorHeader, err.Error()) + case 0: + c.Status(http.StatusOK) + default: + return err + } + + return nil + } + + return err + } + + c.Header(wopi.ItemVersionHeader, res.PrimaryEntity) + return nil +} + +func (service *WopiService) GetFile(c *gin.Context) error { + uri, m, _, viewerSession, dep, err := prepareFs(c) + if err != nil { + return err + } + + // Make sure file exists and readable + file, err := m.Get(c, uri, dbfs.WithExtendedInfo(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return fmt.Errorf("failed to get file: %w", err) + } + + versionType := types.EntityTypeVersion + find, targetEntity := fs.FindDesiredEntity(file, viewerSession.Version, dep.HashIDEncoder(), &versionType) + if !find { + return serializer.NewError(serializer.CodeNotFound, "version not found", nil) + } + + if targetEntity.Size() > dep.SettingProvider().MaxOnlineEditSize(c) { + return fs.ErrFileSizeTooBig + } + + entitySource, err := m.GetEntitySource(c, targetEntity.ID(), fs.WithEntity(targetEntity)) + if err != nil { + return fmt.Errorf("failed to get entity source: %w", err) + } + + defer entitySource.Close() + + entitySource.Serve(c.Writer, c.Request, + entitysource.WithContext(c), + ) + + return nil +} + +func (service *WopiService) FileInfo(c *gin.Context) (*WopiFileInfo, error) { + uri, m, user, viewerSession, dep, err := prepareFs(c) + if err != nil { + return nil, err + } + + hasher := dep.HashIDEncoder() + settings := dep.SettingProvider() + + opts := []fs.Option{ + dbfs.WithFilePublicMetadata(), + dbfs.WithExtendedInfo(), + dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile, dbfs.NavigatorCapabilityInfo, dbfs.NavigatorCapabilityUploadFile), + } + file, err := m.Get(c, uri, opts...) 
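+ // The lookup requests extended info, public metadata and the download/info/upload capabilities a WOPI host needs; failures fall through to the error checks below.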
+ if err != nil { + return nil, fmt.Errorf("failed to get file: %w", err) + } + + if file == nil { + return nil, serializer.NewError(serializer.CodeNotFound, "file not found", nil) + } + + versionType := types.EntityTypeVersion + find, targetEntity := fs.FindDesiredEntity(file, viewerSession.Version, hasher, &versionType) + if !find { + return nil, serializer.NewError(serializer.CodeNotFound, "version not found", nil) + } + + canEdit := file.PrimaryEntityID() == targetEntity.ID() && file.OwnerID() == user.ID + siteUrl := settings.SiteURL(c) + info := &WopiFileInfo{ + BaseFileName: file.DisplayName(), + Version: hashid.EncodeEntityID(hasher, targetEntity.ID()), + BreadcrumbBrandName: settings.SiteBasic(c).Name, + BreadcrumbBrandUrl: siteUrl.String(), + FileSharingPostMessage: file.OwnerID() == user.ID, + EnableShare: file.OwnerID() == user.ID, + FileVersionPostMessage: true, + ClosePostMessage: true, + PostMessageOrigin: "*", + FileNameMaxLength: dbfs.MaxFileNameLength, + LastModifiedTime: file.UpdatedAt().Format(time.RFC3339), + IsAnonymousUser: inventory.IsAnonymousUser(user), + UserFriendlyName: user.Nick, + UserId: hashid.EncodeUserID(hasher, user.ID), + ReadOnly: !canEdit, + Size: targetEntity.Size(), + OwnerId: hashid.EncodeUserID(hasher, file.OwnerID()), + SupportsRename: true, + SupportsReviewing: true, + SupportsLocks: true, + UserCanReview: canEdit, + UserCanWrite: canEdit, + BreadcrumbFolderName: uri.Dir(), + BreadcrumbFolderUrl: routes.FrontendHomeUrl(siteUrl, uri.DirUri().String()).String(), + } + + return info, nil +} + +type ( + CreateViewerSessionService struct { + Uri string `json:"uri" form:"uri" binding:"required"` + Version string `json:"version" form:"version"` + ViewerID string `json:"viewer_id" form:"viewer_id" binding:"required"` + PreferredAction setting.ViewerAction `json:"preferred_action" form:"preferred_action" binding:"required"` + } + CreateViewerSessionParamCtx struct{} +) + +func (s *CreateViewerSessionService) Create(c *gin.Context) (*ViewerSessionResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + uri, err := fs.NewUriFromString(s.Uri) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "unknown uri", err) + } + + // Find the given viewer + viewers := dep.SettingProvider().FileViewers(c) + var targetViewer *setting.Viewer + for _, group := range viewers { + for _, viewer := range group.Viewers { + if viewer.ID == s.ViewerID && !viewer.Disabled { + targetViewer = &viewer + break + } + } + + if targetViewer != nil { + break + } + } + + if targetViewer == nil { + return nil, serializer.NewError(serializer.CodeParamErr, "unknown viewer id", err) + } + + viewerSession, err := m.CreateViewerSession(c, uri, s.Version, targetViewer) + if err != nil { + return nil, err + } + + res := &ViewerSessionResponse{Session: viewerSession} + if targetViewer.Type == setting.ViewerTypeWopi { + // For WOPI viewer, generate WOPI src + wopiSrc, err := wopi.GenerateWopiSrc(c, s.PreferredAction, targetViewer, viewerSession) + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "failed to generate wopi src", err) + } + res.WopiSrc = wopiSrc.String() + } + + return res, nil +} diff --git a/service/explorer/wopi.go b/service/explorer/wopi.go deleted file mode 100644 index 9ee7c30e..00000000 --- a/service/explorer/wopi.go +++ /dev/null @@ -1,138 +0,0 @@ -package explorer - -import ( - "errors" - "fmt" - 
"github.com/cloudreve/Cloudreve/v3/middleware" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/gin-gonic/gin" - "net/http" - "time" -) - -type WopiService struct { -} - -func (service *WopiService) Rename(c *gin.Context) error { - fs, _, err := service.prepareFs(c) - if err != nil { - return err - } - - defer fs.Recycle() - - return fs.Rename(c, []uint{}, []uint{c.MustGet("object_id").(uint)}, c.GetHeader(wopi.RenameRequestHeader)) -} - -func (service *WopiService) GetFile(c *gin.Context) error { - fs, _, err := service.prepareFs(c) - if err != nil { - return err - } - - defer fs.Recycle() - - resp, err := fs.Preview(c, fs.FileTarget[0].ID, true) - if err != nil { - return fmt.Errorf("failed to pull file content: %w", err) - } - - // 重定向到文件源 - if resp.Redirect { - return fmt.Errorf("redirect not supported in WOPI") - } - - // 直接返回文件内容 - defer resp.Content.Close() - - c.Header("Cache-Control", "no-cache") - http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, fs.FileTarget[0].UpdatedAt, resp.Content) - return nil -} - -func (service *WopiService) FileInfo(c *gin.Context) (*serializer.WopiFileInfo, error) { - fs, session, err := service.prepareFs(c) - if err != nil { - return nil, err - } - - defer fs.Recycle() - - parent, err := model.GetFoldersByIDs([]uint{fs.FileTarget[0].FolderID}, fs.User.ID) - if err != nil { - return nil, err - } - - if len(parent) == 0 { - return nil, fmt.Errorf("failed to find parent folder") - } - - parent[0].TraceRoot() - siteUrl := model.GetSiteURL() - - // Generate url for parent folder - parentUrl := model.GetSiteURL() - parentUrl.Path = "/home" - query := parentUrl.Query() - query.Set("path", parent[0].Position) - parentUrl.RawQuery = query.Encode() - - info := &serializer.WopiFileInfo{ - BaseFileName: fs.FileTarget[0].Name, - Version: fs.FileTarget[0].Model.UpdatedAt.String(), - BreadcrumbBrandName: model.GetSettingByName("siteName"), - BreadcrumbBrandUrl: siteUrl.String(), - FileSharingPostMessage: false, - PostMessageOrigin: "*", - FileNameMaxLength: 256, - LastModifiedTime: fs.FileTarget[0].Model.UpdatedAt.Format(time.RFC3339), - IsAnonymousUser: true, - ReadOnly: true, - ClosePostMessage: true, - Size: int64(fs.FileTarget[0].Size), - OwnerId: hashid.HashID(fs.FileTarget[0].UserID, hashid.UserID), - } - - if session.Action == wopi.ActionEdit { - info.FileSharingPostMessage = true - info.IsAnonymousUser = false - info.SupportsRename = true - info.SupportsReviewing = true - info.SupportsUpdate = true - info.UserFriendlyName = fs.User.Nick - info.UserId = hashid.HashID(fs.User.ID, hashid.UserID) - info.UserCanRename = true - info.UserCanReview = true - info.UserCanWrite = true - info.ReadOnly = false - info.BreadcrumbFolderName = parent[0].Name - info.BreadcrumbFolderUrl = parentUrl.String() - } - - return info, nil -} - -func (service *WopiService) prepareFs(c *gin.Context) (*filesystem.FileSystem, *wopi.SessionCache, error) { - // 创建文件系统 - fs, err := filesystem.NewFileSystemFromContext(c) - if err != nil { - return nil, nil, err - } - - session := c.MustGet(middleware.WopiSessionCtx).(*wopi.SessionCache) - if err := fs.SetTargetFileByIDs([]uint{session.FileID}); err != nil { - fs.Recycle() - return nil, nil, fmt.Errorf("failed to find file: %w", err) - } - - maxSize := model.GetIntSetting("maxEditSize", 0) - if maxSize > 0 && 
fs.FileTarget[0].Size > uint64(maxSize) { - return nil, nil, errors.New("file too large") - } - - return fs, session, nil -} diff --git a/service/explorer/workflows.go b/service/explorer/workflows.go new file mode 100644 index 00000000..a6f614c7 --- /dev/null +++ b/service/explorer/workflows.go @@ -0,0 +1,398 @@ +package explorer + +import ( + "encoding/gob" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/downloader" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/workflows" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "github.com/samber/lo" + "golang.org/x/tools/container/intsets" +) + +// ItemMoveService 处理多文件/目录移动 +type ItemMoveService struct { + SrcDir string `json:"src_dir" binding:"required,min=1,max=65535"` + Src ItemIDService `json:"src"` + Dst string `json:"dst" binding:"required,min=1,max=65535"` +} + +// ItemRenameService 处理多文件/目录重命名 +type ItemRenameService struct { + Src ItemIDService `json:"src"` + NewName string `json:"new_name" binding:"required,min=1,max=255"` +} + +// ItemService 处理多文件/目录相关服务 +type ItemService struct { + Items []uint `json:"items"` + Dirs []uint `json:"dirs"` +} + +// ItemIDService 处理多文件/目录相关服务,字段值为HashID,可通过Raw()方法获取原始ID +type ItemIDService struct { + Items []string `json:"items"` + Dirs []string `json:"dirs"` + Source *ItemService + Force bool `json:"force"` + UnlinkOnly bool `json:"unlink"` +} + +// ItemDecompressService 文件解压缩任务服务 +type ItemDecompressService struct { + Src string `json:"src"` + Dst string `json:"dst" binding:"required,min=1,max=65535"` + Encoding string `json:"encoding"` +} + +// ItemPropertyService 获取对象属性服务 +type ItemPropertyService struct { + ID string `binding:"required"` + TraceRoot bool `form:"trace_root"` + IsFolder bool `form:"is_folder"` +} + +func init() { + gob.Register(ItemIDService{}) +} + +type ( + DownloadWorkflowService struct { + Src []string `json:"src"` + SrcFile string `json:"src_file"` + Dst string `json:"dst" binding:"required"` + } + CreateDownloadParamCtx struct{} +) + +func (service *DownloadWorkflowService) CreateDownloadTask(c *gin.Context) ([]*TaskResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionRemoteDownload)) { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "Group not allowed to download files", nil) + } + + // Src must be set + if service.SrcFile == "" && len(service.Src) == 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "No source files", nil) + } + + // Only one of src and src_file can be set + if service.SrcFile != "" && len(service.Src) > 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid source files", nil) + } + + dst, err := fs.NewUriFromString(service.Dst) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid destination", err) + } + + // Validate dst + _, err = 
m.Get(c, dst, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityCreateFile)) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid destination", err) + } + + // 检查批量任务数量 + limit := user.Edges.Group.Settings.Aria2BatchSize + if limit > 0 && len(service.Src) > limit { + return nil, serializer.NewError(serializer.CodeBatchAria2Size, "", nil) + } + + // Validate src file + if service.SrcFile != "" { + src, err := fs.NewUriFromString(service.SrcFile) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid source file uri", err) + } + + _, err = m.Get(c, src, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile)) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid source file", err) + } + } + + // batch creating tasks + ae := serializer.NewAggregateError() + tasks := make([]queue.Task, 0, len(service.Src)) + for _, src := range service.Src { + if src == "" { + continue + } + + t, err := workflows.NewRemoteDownloadTask(c, src, service.SrcFile, service.Dst) + if err != nil { + ae.Add(src, err) + continue + } + + if err := dep.RemoteDownloadQueue(c).QueueTask(c, t); err != nil { + ae.Add(src, err) + } + + tasks = append(tasks, t) + } + + if service.SrcFile != "" { + t, err := workflows.NewRemoteDownloadTask(c, "", service.SrcFile, service.Dst) + if err != nil { + ae.Add(service.SrcFile, err) + } + + if err := dep.RemoteDownloadQueue(c).QueueTask(c, t); err != nil { + ae.Add(service.SrcFile, err) + } + + tasks = append(tasks, t) + } + + return lo.Map(tasks, func(item queue.Task, index int) *TaskResponse { + return BuildTaskResponse(item, nil, hasher) + }), ae.Aggregate() +} + +type ( + ArchiveWorkflowService struct { + Src []string `json:"src" binding:"required"` + Dst string `json:"dst" binding:"required"` + Encoding string `json:"encoding"` + } + CreateArchiveParamCtx struct{} +) + +func (service *ArchiveWorkflowService) CreateExtractTask(c *gin.Context) (*TaskResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionArchiveTask)) { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "Group not allowed to compress files", nil) + } + + dst, err := fs.NewUriFromString(service.Dst) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid destination", err) + } + + if len(service.Src) == 0 { + return nil, serializer.NewError(serializer.CodeParamErr, "No source files", nil) + } + + // Validate destination + if _, err := m.Get(c, dst, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityCreateFile)); err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid destination", err) + } + + // Create task + t, err := workflows.NewExtractArchiveTask(c, service.Src[0], service.Dst, service.Encoding) + if err != nil { + return nil, serializer.NewError(serializer.CodeCreateTaskError, "Failed to create task", err) + } + + if err := dep.IoIntenseQueue(c).QueueTask(c, t); err != nil { + return nil, serializer.NewError(serializer.CodeCreateTaskError, "Failed to queue task", err) + } + + return BuildTaskResponse(t, nil, hasher), nil +} + +// CreateCompressTask Create task to create an archive file +func (service *ArchiveWorkflowService) CreateCompressTask(c *gin.Context) (*TaskResponse, error) { + dep := dependency.FromContext(c) + user := 
inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionArchiveTask)) { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "Group not allowed to compress files", nil) + } + + dst, err := fs.NewUriFromString(service.Dst) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid destination", err) + } + + // Create a placeholder file then delete it to validate the destination + session, err := m.PrepareUpload(c, &fs.UploadRequest{ + Props: &fs.UploadProps{ + Uri: dst, + Size: 0, + UploadSessionID: uuid.Must(uuid.NewV4()).String(), + ExpireAt: time.Now().Add(time.Second * 3600), + }, + }) + if err != nil { + return nil, err + } + m.OnUploadFailed(c, session) + + // Create task + t, err := workflows.NewCreateArchiveTask(c, service.Src, service.Dst) + if err != nil { + return nil, serializer.NewError(serializer.CodeCreateTaskError, "Failed to create task", err) + } + + if err := dep.IoIntenseQueue(c).QueueTask(c, t); err != nil { + return nil, serializer.NewError(serializer.CodeCreateTaskError, "Failed to queue task", err) + } + + return BuildTaskResponse(t, nil, hasher), nil +} + +type ( + ListTaskService struct { + PageSize int `form:"page_size" binding:"required,min=10,max=100"` + Category string `form:"category" binding:"required,eq=general|eq=downloading|eq=downloaded"` + NextPageToken string `form:"next_page_token"` + } + ListTaskParamCtx struct{} +) + +func (service *ListTaskService) ListTasks(c *gin.Context) (*TaskListResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + taskClient := dep.TaskClient() + + args := &inventory.ListTaskArgs{ + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageToken: service.NextPageToken, + PageSize: service.PageSize, + }, + Types: []string{queue.CreateArchiveTaskType, queue.ExtractArchiveTaskType, queue.RelocateTaskType}, + UserID: user.ID, + } + + if service.Category != "general" { + args.Types = []string{queue.RemoteDownloadTaskType} + if service.Category == "downloading" { + args.PageSize = intsets.MaxInt + args.Status = []task.Status{task.StatusSuspending, task.StatusProcessing, task.StatusQueued} + } else if service.Category == "downloaded" { + args.Status = []task.Status{task.StatusCanceled, task.StatusError, task.StatusCompleted} + } + } + + // Get tasks + res, err := taskClient.List(c, args) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to query tasks", err) + } + + tasks := make([]queue.Task, 0, len(res.Tasks)) + nodeMap := make(map[int]*ent.Node) + for _, t := range res.Tasks { + task, err := queue.NewTaskFromModel(t) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to parse task", err) + } + + summary := task.Summarize(hasher) + if summary != nil && summary.NodeID > 0 { + if _, ok := nodeMap[summary.NodeID]; !ok { + nodeMap[summary.NodeID] = nil + } + } + tasks = append(tasks, task) + } + + // Get nodes + nodes, err := dep.NodeClient().ListActiveNodes(c, lo.Keys(nodeMap)) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to query nodes", err) + } + for _, n := range nodes { + nodeMap[n.ID] = n + } + + // Build response + return BuildTaskListResponse(tasks, res, nodeMap, hasher), nil +} + +func TaskPhaseProgress(c *gin.Context, taskID int) (queue.Progresses, error) { 
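+ // Progress is served from the in-memory task registry; an unknown task, or one owned by another user, yields an empty result rather than an error.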
+ dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + r := dep.TaskRegistry() + t, found := r.Get(taskID) + if !found || t.Owner().ID != u.ID { + return queue.Progresses{}, nil + } + + return t.Progress(c), nil +} + +func CancelDownloadTask(c *gin.Context, taskID int) error { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + r := dep.TaskRegistry() + t, found := r.Get(taskID) + if !found || t.Owner().ID != u.ID { + return serializer.NewError(serializer.CodeNotFound, "Task not found", nil) + } + + if downloadTask, ok := t.(*workflows.RemoteDownloadTask); ok { + if err := downloadTask.CancelDownload(c); err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to cancel download task", err) + } + } + + return nil +} + +type ( + SetDownloadFilesService struct { + Files []*downloader.SetFileToDownloadArgs `json:"files" binding:"required"` + } + SetDownloadFilesParamCtx struct{} +) + +func (service *SetDownloadFilesService) SetDownloadFiles(c *gin.Context, taskID int) error { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + r := dep.TaskRegistry() + + t, found := r.Get(taskID) + if !found || t.Owner().ID != u.ID { + return serializer.NewError(serializer.CodeNotFound, "Task not found", nil) + } + + status := t.Status() + summary := t.Summarize(dep.HashIDEncoder()) + // Task must be in processing state + if status != task.StatusSuspending && status != task.StatusProcessing { + return serializer.NewError(serializer.CodeNotFound, "Task not in processing state", nil) + } + + // Task must in monitoring loop + if summary.Phase != workflows.RemoteDownloadTaskPhaseMonitor { + return serializer.NewError(serializer.CodeNotFound, "Task not in monitoring loop", nil) + } + + if downloadTask, ok := t.(*workflows.RemoteDownloadTask); ok { + if err := downloadTask.SetDownloadTarget(c, service.Files...); err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to set download files", err) + } + } + + return nil +} diff --git a/service/node/fabric.go b/service/node/fabric.go deleted file mode 100644 index deb21840..00000000 --- a/service/node/fabric.go +++ /dev/null @@ -1,76 +0,0 @@ -package node - -import ( - "encoding/gob" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" -) - -type SlaveNotificationService struct { - Subject string `uri:"subject" binding:"required"` -} - -type OauthCredentialService struct { - PolicyID uint `uri:"id" binding:"required"` -} - -func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { - res, err := cluster.DefaultController.HandleHeartBeat(req) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err) - } - - return serializer.Response{ - Code: 0, - Data: res, - } -} - -// HandleSlaveNotificationPush 转发从机的消息通知到本机消息队列 -func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) serializer.Response { - var msg mq.Message - dec := gob.NewDecoder(c.Request.Body) - if err := dec.Decode(&msg); err != nil { - return serializer.ParamErr("Cannot parse 
notification message", err) - } - - mq.GlobalMQ.Publish(s.Subject, msg) - return serializer.Response{} -} - -// Get 获取主机Oauth策略的AccessToken -func (s *OauthCredentialService) Get(c *gin.Context) serializer.Response { - policy, err := model.GetPolicyByID(s.PolicyID) - if err != nil { - return serializer.Err(serializer.CodePolicyNotExist, "", err) - } - - var client oauth.TokenProvider - switch policy.Type { - case "onedrive": - client, err = onedrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err) - } - case "googledrive": - client, err = googledrive.NewClient(&policy) - if err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize Google Drive client", err) - } - default: - return serializer.Err(serializer.CodePolicyNotExist, "", nil) - } - - if err := client.UpdateCredential(c, conf.SystemConfig.Mode == "slave"); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Cannot refresh OneDrive credential", err) - } - - return serializer.Response{Data: client.AccessToken()} -} diff --git a/service/node/response.go b/service/node/response.go new file mode 100644 index 00000000..2b4023a6 --- /dev/null +++ b/service/node/response.go @@ -0,0 +1 @@ +package node diff --git a/service/node/rpc.go b/service/node/rpc.go new file mode 100644 index 00000000..4fadf19d --- /dev/null +++ b/service/node/rpc.go @@ -0,0 +1,120 @@ +package node + +import ( + "context" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/credmanager" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-gonic/gin" +) + +type SlaveNotificationService struct { + Subject string `uri:"subject" binding:"required"` +} + +type ( + OauthCredentialParamCtx struct{} + OauthCredentialService struct { + ID string `uri:"id" binding:"required"` + } +) + +// Get 获取主机Oauth策略的AccessToken +func (s *OauthCredentialService) Get(c *gin.Context) (*credmanager.CredentialResponse, error) { + dep := dependency.FromContext(c) + credManager := dep.CredManager() + + cred, err := credManager.Obtain(c, s.ID) + if cred == nil || err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "Credential not found", err) + } + + return &credmanager.CredentialResponse{ + Token: cred.String(), + ExpireAt: cred.Expiry(), + }, nil +} + +type ( + StatelessPrepareUploadParamCtx struct{} +) + +func StatelessPrepareUpload(s *fs.StatelessPrepareUploadService, c *gin.Context) (*fs.StatelessPrepareUploadResponse, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + user, err := userClient.GetLoginUserByID(c, s.UserID) + if err != nil { + return nil, err + } + + ctx := context.WithValue(c.Request.Context(), inventory.UserCtx{}, user) + fm := manager.NewFileManager(dep, user) + uploadSession, err := fm.PrepareUpload(ctx, s.UploadRequest) + if err != nil { + return nil, err + } + return &fs.StatelessPrepareUploadResponse{ + Session: uploadSession, + Req: s.UploadRequest, + }, nil +} + +type ( + StatelessCompleteUploadParamCtx struct{} +) + +func StatelessCompleteUpload(s *fs.StatelessCompleteUploadService, c *gin.Context) (fs.File, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + user, err := 
userClient.GetLoginUserByID(c, s.UserID) + if err != nil { + return nil, err + } + + util.WithValue(c, inventory.UserCtx{}, user) + fm := manager.NewFileManager(dep, user) + return fm.CompleteUpload(c, s.UploadSession) +} + +type ( + StatelessOnUploadFailedParamCtx struct{} +) + +func StatelessOnUploadFailed(s *fs.StatelessOnUploadFailedService, c *gin.Context) error { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + user, err := userClient.GetLoginUserByID(c, s.UserID) + if err != nil { + return err + } + + util.WithValue(c, inventory.UserCtx{}, user) + fm := manager.NewFileManager(dep, user) + fm.OnUploadFailed(c, s.UploadSession) + return nil +} + +type StatelessCreateFileParamCtx struct{} + +func StatelessCreateFile(s *fs.StatelessCreateFileService, c *gin.Context) error { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + user, err := userClient.GetLoginUserByID(c, s.UserID) + if err != nil { + return err + } + + uri, err := fs.NewUriFromString(s.Path) + if err != nil { + return err + } + + util.WithValue(c, inventory.UserCtx{}, user) + fm := manager.NewFileManager(dep, user) + _, err = fm.Create(c, uri, s.Type) + return err +} diff --git a/service/node/task.go b/service/node/task.go new file mode 100644 index 00000000..85cc76a8 --- /dev/null +++ b/service/node/task.go @@ -0,0 +1,150 @@ +package node + +import ( + "context" + "fmt" + "os" + "strconv" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent/task" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/workflows" + "github.com/cloudreve/Cloudreve/v4/pkg/logging" + "github.com/cloudreve/Cloudreve/v4/pkg/queue" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" +) + +type ( + CreateSlaveTaskParamCtx struct{} +) + +func CreateTaskInSlave(s *cluster.CreateSlaveTask, c *gin.Context) (int, error) { + dep := dependency.FromContext(c) + registry := dep.TaskRegistry() + + props, err := slaveTaskPropsFromContext(c) + if err != nil { + return 0, serializer.NewError(serializer.CodeParamErr, "failed to get master props from header", err) + } + + var t queue.Task + switch s.Type { + case queue.SlaveUploadTaskType: + t = workflows.NewSlaveUploadTask(c, props, registry.NextID(), s.State) + case queue.SlaveCreateArchiveTaskType: + t = workflows.NewSlaveCreateArchiveTask(c, props, registry.NextID(), s.State) + case queue.SlaveExtractArchiveType: + t = workflows.NewSlaveExtractArchiveTask(c, props, registry.NextID(), s.State) + default: + return 0, serializer.NewError(serializer.CodeParamErr, "type not supported", nil) + } + + if err := dep.SlaveQueue(c).QueueTask(c, t); err != nil { + return 0, serializer.NewError(serializer.CodeInternalSetting, "failed to queue task", err) + } + + registry.Set(t.ID(), t) + return t.ID(), nil +} + +type ( + GetSlaveTaskParamCtx struct{} + GetSlaveTaskService struct { + ID int `uri:"id" binding:"required"` + } +) + +func (s *GetSlaveTaskService) Get(c *gin.Context) (*cluster.SlaveTaskSummary, error) { + dep := dependency.FromContext(c) + registry := dep.TaskRegistry() + + t, ok := registry.Get(s.ID) + if !ok { + return nil, serializer.NewError(serializer.CodeNotFound, "task not found", nil) + } + status := t.Status() + _, clearOnComplete := c.GetQuery(routes.SlaveClearTaskRegistryQuery) + if clearOnComplete && status == 
task.StatusCompleted || + status == task.StatusError || + status == task.StatusCanceled { + registry.Delete(s.ID) + } + + res := &cluster.SlaveTaskSummary{ + Status: status, + PrivateState: t.State(), + Progress: t.Progress(c), + } + err := t.Error() + if err != nil { + res.Error = err.Error() + } + + return res, nil +} + +func slaveTaskPropsFromContext(ctx context.Context) (*types.SlaveTaskProps, error) { + nodeIdStr, ok := ctx.Value(cluster.SlaveNodeIDCtx{}).(string) + if !ok { + return nil, fmt.Errorf("failed to get node ID from context") + } + + nodeId, err := strconv.Atoi(nodeIdStr) + if err != nil { + return nil, fmt.Errorf("failed to convert node ID to int: %w", err) + } + + masterSiteUrl := cluster.MasterSiteUrlFromContext(ctx) + if masterSiteUrl == "" { + return nil, fmt.Errorf("failed to get master site URL from context") + } + + masterSiteVersion, ok := ctx.Value(cluster.MasterSiteVersionCtx{}).(string) + if !ok { + return nil, fmt.Errorf("failed to get master site version from context") + } + + masterSiteId, ok := ctx.Value(cluster.MasterSiteIDCtx{}).(string) + if !ok { + return nil, fmt.Errorf("failed to convert master site ID to int: %w", err) + } + + props := &types.SlaveTaskProps{ + NodeID: nodeId, + MasterSiteID: masterSiteId, + MasterSiteURl: masterSiteUrl, + MasterSiteVersion: masterSiteVersion, + } + + return props, nil +} + +type ( + FolderCleanupParamCtx struct{} +) + +func Cleanup(args *cluster.FolderCleanup, c *gin.Context) error { + l := logging.FromContext(c) + ae := serializer.NewAggregateError() + for _, p := range args.Path { + l.Info("Cleaning up folder %q", p) + if err := os.RemoveAll(p); err != nil { + l.Warning("Failed to clean up folder %q: %s", p, err) + ae.Add(p, err) + } + } + + return ae.Aggregate() +} + +type ( + CreateSlaveDownloadTaskParamCtx struct{} + GetSlaveDownloadTaskParamCtx struct{} + CancelSlaveDownloadTaskParamCtx struct{} + SelectSlaveDownloadFilesParamCtx struct{} + TestSlaveDownloadParamCtx struct{} +) diff --git a/service/setting/response.go b/service/setting/response.go new file mode 100644 index 00000000..165a4490 --- /dev/null +++ b/service/setting/response.go @@ -0,0 +1,44 @@ +package setting + +import ( + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/samber/lo" + "time" +) + +type ListDavAccountResponse struct { + Accounts []DavAccount `json:"accounts"` + Pagination *inventory.PaginationResults `json:"pagination"` +} + +func BuildListDavAccountResponse(res *inventory.ListDavAccountResult, hasher hashid.Encoder) *ListDavAccountResponse { + return &ListDavAccountResponse{ + Accounts: lo.Map(res.Accounts, func(item *ent.DavAccount, index int) DavAccount { + return BuildDavAccount(item, hasher) + }), + Pagination: res.PaginationResults, + } +} + +type DavAccount struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + Name string `json:"name"` + Uri string `json:"uri"` + Password string `json:"password"` + Options *boolset.BooleanSet `json:"options"` +} + +func BuildDavAccount(account *ent.DavAccount, hasher hashid.Encoder) DavAccount { + return DavAccount{ + ID: hashid.EncodeDavAccountID(hasher, account.ID), + CreatedAt: account.CreatedAt, + Name: account.Name, + Uri: account.URI, + Password: account.Password, + Options: account.Options, + } +} diff --git a/service/setting/webdav.go b/service/setting/webdav.go index 1f817516..640776c7 100644 --- 
a/service/setting/webdav.go +++ b/service/setting/webdav.go @@ -1,16 +1,19 @@ package setting import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v4/application/constants" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" ) -// WebDAVListService WebDAV 列表服务 -type WebDAVListService struct { -} - // WebDAVAccountService WebDAV 账号管理服务 type WebDAVAccountService struct { ID uint `uri:"id" binding:"required,min=1"` @@ -22,11 +25,10 @@ type WebDAVAccountCreateService struct { Name string `json:"name" binding:"required,min=1,max=255"` } -// WebDAVAccountUpdateService WebDAV 修改只读性和是否使用代理服务 -type WebDAVAccountUpdateService struct { - ID uint `json:"id" binding:"required,min=1"` - Readonly *bool `json:"readonly" binding:"required_without=UseProxy"` - UseProxy *bool `json:"use_proxy" binding:"required_without=Readonly"` +// WebDAVAccountUpdateReadonlyService WebDAV 修改只读性服务 +type WebDAVAccountUpdateReadonlyService struct { + ID uint `json:"id" binding:"required,min=1"` + Readonly bool `json:"readonly"` } // WebDAVMountCreateService WebDAV 挂载创建服务 @@ -35,52 +37,163 @@ type WebDAVMountCreateService struct { Policy string `json:"policy" binding:"required,min=1"` } -// Create 创建WebDAV账户 -func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.User) serializer.Response { - account := model.Webdav{ - Name: service.Name, - Password: util.RandStringRunes(32), - UserID: user.ID, - Root: service.Path, - } +//// Unmount 取消目录挂载 +//func (service *WebDAVListService) Unmount(c *gin.Context, user *model.User) serializer.Response { +// folderID, _ := c.Get("object_id") +// folder, err := model.GetFoldersByIDs([]uint{folderID.(uint)}, user.ID) +// if err != nil || len(folder) == 0 { +// return serializer.ErrDeprecated(serializer.CodeParentNotExist, "", err) +// } +// +// if err := folder[0].Mount(0); err != nil { +// return serializer.DBErrDeprecated("Failed to update folder record", err) +// } +// +// return serializer.Response{} +//} - if _, err := account.Create(); err != nil { - return serializer.Err(serializer.CodeDBError, "创建失败", err) +type ( + ListDavAccountsService struct { + PageSize int `form:"page_size" binding:"required,min=10,max=100"` + NextPageToken string `form:"next_page_token"` } + ListDavAccountParamCtx struct{} +) - return serializer.Response{ - Data: map[string]interface{}{ - "id": account.ID, - "password": account.Password, - "created_at": account.CreatedAt, +// Accounts 列出WebDAV账号 +func (service *ListDavAccountsService) List(c *gin.Context) (*ListDavAccountResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + davAccountClient := dep.DavAccountClient() + + args := &inventory.ListDavAccountArgs{ + UserID: user.ID, + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageSize: service.PageSize, + PageToken: service.NextPageToken, }, } + + res, err := davAccountClient.List(c, args) + if err != nil { + return nil, 
serializer.NewError(serializer.CodeDBError, "Failed to list dav accounts", err) + } + + return BuildListDavAccountResponse(res, hasher), nil +} + +type ( + CreateDavAccountService struct { + Uri string `json:"uri" binding:"required"` + Name string `json:"name" binding:"required,min=1,max=255"` + Readonly bool `json:"readonly"` + Proxy bool `json:"proxy"` + } + CreateDavAccountParamCtx struct{} +) + +// Create 创建WebDAV账号 +func (service *CreateDavAccountService) Create(c *gin.Context) (*DavAccount, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + + bs, err := service.validateAndGetBs(user) + if err != nil { + return nil, err + } + + davAccountClient := dep.DavAccountClient() + account, err := davAccountClient.Create(c, &inventory.CreateDavAccountParams{ + UserID: user.ID, + Name: service.Name, + URI: service.Uri, + Password: util.RandString(32, util.RandomLowerCases), + Options: bs, + }) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to create dav account", err) + } + + accountRes := BuildDavAccount(account, dep.HashIDEncoder()) + return &accountRes, nil } -// Delete 删除WebDAV账户 -func (service *WebDAVAccountService) Delete(c *gin.Context, user *model.User) serializer.Response { - model.DeleteWebDAVAccountByID(service.ID, user.ID) - return serializer.Response{} +// Update updates an existing account +func (service *CreateDavAccountService) Update(c *gin.Context) (*DavAccount, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + accountId := hashid.FromContext(c) + + // Find existing account + davAccountClient := dep.DavAccountClient() + account, err := davAccountClient.GetByIDAndUserID(c, accountId, user.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "Account not exist", err) + } + + bs, err := service.validateAndGetBs(user) + if err != nil { + return nil, err + } + + // Update account + account, err = davAccountClient.Update(c, accountId, &inventory.CreateDavAccountParams{ + Name: service.Name, + URI: service.Uri, + Options: bs, + }) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to update dav account", err) + } + + accountRes := BuildDavAccount(account, dep.HashIDEncoder()) + return &accountRes, nil } -// Update 修改WebDAV账户只读性和是否使用代理服务 -func (service *WebDAVAccountUpdateService) Update(c *gin.Context, user *model.User) serializer.Response { - var updates = make(map[string]interface{}) - if service.Readonly != nil { - updates["readonly"] = *service.Readonly +func (service *CreateDavAccountService) validateAndGetBs(user *ent.User) (*boolset.BooleanSet, error) { + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionWebDAV)) { + return nil, serializer.NewError(serializer.CodeGroupNotAllowed, "WebDAV is not enabled for this user group", nil) } - if service.UseProxy != nil { - updates["use_proxy"] = *service.UseProxy + + uri, err := fs.NewUriFromString(service.Uri) + if err != nil { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid URI", err) } - model.UpdateWebDAVAccountByID(service.ID, user.ID, updates) - return serializer.Response{Data: updates} + + // Only "my" and "share" fs is allowed in WebDAV + if uriFs := uri.FileSystem(); uri.SearchParameters() != nil || + (uriFs != constants.FileSystemMy && uriFs != constants.FileSystemShare) { + return nil, serializer.NewError(serializer.CodeParamErr, "Invalid URI", nil) + } + + bs := boolset.BooleanSet{} + if service.Readonly { + 
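+ // Record the read-only flag in the account's option bitset.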
boolset.Set(types.DavAccountReadOnly, true, &bs) + } + + if service.Proxy && user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionWebDAVProxy)) { + boolset.Set(types.DavAccountProxy, true, &bs) + } + return &bs, nil } -// Accounts 列出WebDAV账号 -func (service *WebDAVListService) Accounts(c *gin.Context, user *model.User) serializer.Response { - accounts := model.ListWebDAVAccounts(user.ID) +func DeleteDavAccount(c *gin.Context) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + accountId := hashid.FromContext(c) + + // Find existing account + davAccountClient := dep.DavAccountClient() + _, err := davAccountClient.GetByIDAndUserID(c, accountId, user.ID) + if err != nil { + return serializer.NewError(serializer.CodeNotFound, "Account not exist", err) + } + + if err := davAccountClient.Delete(c, accountId); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete dav account", err) + } - return serializer.Response{Data: map[string]interface{}{ - "accounts": accounts, - }} + return nil } diff --git a/service/share/manage.go b/service/share/manage.go index 9daccdb4..2d0de0a2 100644 --- a/service/share/manage.go +++ b/service/share/manage.go @@ -1,150 +1,90 @@ package share import ( - "net/url" + "context" "time" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/explorer" "github.com/gin-gonic/gin" ) -// ShareCreateService 创建新分享服务 -type ShareCreateService struct { - SourceID string `json:"id" binding:"required"` - IsDir bool `json:"is_dir"` - Password string `json:"password" binding:"max=255"` - RemainDownloads int `json:"downloads"` - Expire int `json:"expire"` - Preview bool `json:"preview"` -} +type ( + // ShareCreateService 创建新分享服务 + ShareCreateService struct { + Uri string `json:"uri" binding:"required"` + IsPrivate bool `json:"is_private"` + RemainDownloads int `json:"downloads"` + Expire int `json:"expire"` + } + ShareCreateParamCtx struct{} +) -// ShareUpdateService 分享更新服务 -type ShareUpdateService struct { - Prop string `json:"prop" binding:"required,eq=password|eq=preview_enabled"` - Value string `json:"value" binding:"max=255"` -} +// Upsert 创建或更新分享 +func (service *ShareCreateService) Upsert(c *gin.Context, existed int) (string, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() -// Delete 删除分享 -func (service *Service) Delete(c *gin.Context, user *model.User) serializer.Response { - share := model.GetShareByHashID(c.Param("id")) - if share == nil || share.Creator().ID != user.ID { - return serializer.Err(serializer.CodeShareLinkNotFound, "", nil) + // Check group permission for creating share link + if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionShare)) { + return "", serializer.NewError(serializer.CodeGroupNotAllowed, "Group permission denied", nil) } - if err := share.Delete(); err != nil { - return serializer.DBErr("Failed to delete share record", err) + uri, err := 
fs.NewUriFromString(service.Uri) + if err != nil { + return "", serializer.NewError(serializer.CodeParamErr, "unknown uri", err) } - return serializer.Response{} -} - -// Update 更新分享属性 -func (service *ShareUpdateService) Update(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - switch service.Prop { - case "password": - err := share.Update(map[string]interface{}{"password": service.Value}) - if err != nil { - return serializer.DBErr("Failed to update share record", err) - } - case "preview_enabled": - value := service.Value == "true" - err := share.Update(map[string]interface{}{"preview_enabled": value}) - if err != nil { - return serializer.DBErr("Failed to update share record", err) - } - return serializer.Response{ - Data: value, - } + var expires *time.Time + if service.Expire > 0 { + expires = new(time.Time) + *expires = time.Now().Add(time.Duration(service.Expire) * time.Second) } - return serializer.Response{ - Data: service.Value, + + share, err := m.CreateOrUpdateShare(c, uri, &manager.CreateShareArgs{ + IsPrivate: service.IsPrivate, + RemainDownloads: service.RemainDownloads, + Expire: expires, + ExistedShareID: existed, + }) + if err != nil { + return "", err } -} -// Create 创建新分享 -func (service *ShareCreateService) Create(c *gin.Context) serializer.Response { - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) + base := dep.SettingProvider().SiteURL(c) + return explorer.BuildShareLink(share, dep.HashIDEncoder(), base), nil +} - // 是否拥有权限 - if !user.Group.ShareEnabled { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } +func DeleteShare(c *gin.Context, shareId int) error { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + shareClient := dep.ShareClient() - // 源对象真实ID + ctx := context.WithValue(c, inventory.LoadShareFile{}, true) var ( - sourceID uint - sourceName string - err error + share *ent.Share + err error ) - if service.IsDir { - sourceID, err = hashid.DecodeHashID(service.SourceID, hashid.FolderID) + if user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) { + share, err = shareClient.GetByID(ctx, shareId) } else { - sourceID, err = hashid.DecodeHashID(service.SourceID, hashid.FileID) + share, err = shareClient.GetByIDUser(ctx, shareId, user.ID) } if err != nil { - return serializer.Err(serializer.CodeNotFound, "", nil) + return serializer.NewError(serializer.CodeNotFound, "share not found", err) } - // 对象是否存在 - exist := true - if service.IsDir { - folder, err := model.GetFoldersByIDs([]uint{sourceID}, user.ID) - if err != nil || len(folder) == 0 { - exist = false - } else { - sourceName = folder[0].Name - } - } else { - file, err := model.GetFilesByIDs([]uint{sourceID}, user.ID) - if err != nil || len(file) == 0 { - exist = false - } else { - sourceName = file[0].Name - } - } - if !exist { - return serializer.Err(serializer.CodeNotFound, "", nil) - } - - newShare := model.Share{ - Password: service.Password, - IsDir: service.IsDir, - UserID: user.ID, - SourceID: sourceID, - RemainDownloads: -1, - PreviewEnabled: service.Preview, - SourceName: sourceName, - } - - // 如果开启了自动过期 - if service.RemainDownloads > 0 { - expires := time.Now().Add(time.Duration(service.Expire) * time.Second) - newShare.RemainDownloads = service.RemainDownloads - newShare.Expires = &expires - } - - // 创建分享 - id, err := newShare.Create() - if err != nil { - return serializer.DBErr("Failed to create share link record", err) - } - - // 获取分享的唯一id - uid := hashid.HashID(id, 
hashid.ShareID) - // 最终得到分享链接 - siteURL := model.GetSiteURL() - sharePath, _ := url.Parse("/s/" + uid) - shareURL := siteURL.ResolveReference(sharePath) - - return serializer.Response{ - Code: 0, - Data: shareURL.String(), + if err := shareClient.Delete(c, share.ID); err != nil { + return serializer.NewError(serializer.CodeDBError, "Failed to delete share", err) } + return nil } diff --git a/service/share/response.go b/service/share/response.go new file mode 100644 index 00000000..cd3f6a47 --- /dev/null +++ b/service/share/response.go @@ -0,0 +1,28 @@ +package share + +import ( + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/service/explorer" + "net/url" +) + +type ListShareResponse struct { + Shares []explorer.Share `json:"shares"` + Pagination *inventory.PaginationResults `json:"pagination"` +} + +func BuildListShareResponse(res *inventory.ListShareResult, hasher hashid.Encoder, base *url.URL, requester *ent.User, unlocked bool) *ListShareResponse { + var infos []explorer.Share + for _, share := range res.Shares { + infos = append(infos, *explorer.BuildShare(share, base, hasher, requester, share.Edges.User, share.Edges.File.Name, + types.FileType(share.Edges.File.Type), unlocked)) + } + + return &ListShareResponse{ + Shares: infos, + Pagination: res.PaginationResults, + } +} diff --git a/service/share/visit.go b/service/share/visit.go index caad0601..462ce45f 100644 --- a/service/share/visit.go +++ b/service/share/visit.go @@ -2,414 +2,159 @@ package share import ( "context" - "fmt" - "net/http" - "path" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/cloudreve/Cloudreve/v3/service/explorer" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/service/explorer" "github.com/gin-gonic/gin" ) -// ShareUserGetService 获取用户的分享服务 -type ShareUserGetService struct { - Type string `form:"type" binding:"required,eq=hot|eq=default"` - Page uint `form:"page" binding:"required,min=1"` -} - -// ShareGetService 获取分享服务 -type ShareGetService struct { - Password string `form:"password" binding:"max=255"` -} - -// Service 对分享进行操作的服务, -// path 为可选文件完整路径,在目录分享下有效 -type Service struct { - Path string `form:"path" uri:"path" binding:"max=65535"` -} - -// ArchiveService 分享归档下载服务 -type ArchiveService struct { - Path string `json:"path" binding:"required,max=65535"` - Items []string `json:"items"` - Dirs []string `json:"dirs"` -} - -// ShareListService 列出分享 -type ShareListService struct { - Page uint `form:"page" binding:"required,min=1"` - OrderBy string `form:"order_by" binding:"required,eq=created_at|eq=downloads|eq=views"` - Order string `form:"order" binding:"required,eq=DESC|eq=ASC"` - Keywords 
string `form:"keywords"` -} - -// Get 获取给定用户的分享 -func (service *ShareUserGetService) Get(c *gin.Context) serializer.Response { - // 取得用户 - userID, _ := c.Get("object_id") - user, err := model.GetActiveUserByID(userID.(uint)) - if err != nil || user.OptionsSerialized.ProfileOff { - return serializer.Err(serializer.CodeNotFound, "", err) - } - - // 列出分享 - hotNum := model.GetIntSetting("hot_share_num", 10) - if service.Type == "default" { - hotNum = 10 - } - orderBy := "created_at desc" - if service.Type == "hot" { - orderBy = "views desc" - } - shares, total := model.ListShares(user.ID, int(service.Page), hotNum, orderBy, true) - // 列出分享对应的文件 - for i := 0; i < len(shares); i++ { - shares[i].Source() - } - - res := serializer.BuildShareList(shares, total) - res.Data.(map[string]interface{})["user"] = struct { - ID string `json:"id"` - Nick string `json:"nick"` - Group string `json:"group"` - Date string `json:"date"` - }{ - hashid.HashID(user.ID, hashid.UserID), - user.Nick, - user.Group.Name, - user.CreatedAt.Format("2006-01-02 15:04:05"), - } - - return res -} - -// Search 搜索公共分享 -func (service *ShareListService) Search(c *gin.Context) serializer.Response { - // 列出分享 - shares, total := model.SearchShares(int(service.Page), 18, service.OrderBy+" "+ - service.Order, service.Keywords) - // 列出分享对应的文件 - for i := 0; i < len(shares); i++ { - shares[i].Source() +type ( + ShortLinkRedirectService struct { + ID string `uri:"id" binding:"required"` + Password string `uri:"password"` } + ShortLinkRedirectParamCtx struct{} +) - return serializer.BuildShareList(shares, total) +func (s *ShortLinkRedirectService) RedirectTo(c *gin.Context) string { + return routes.MasterShareLongUrl(s.ID, s.Password).String() } -// List 列出用户分享 -func (service *ShareListService) List(c *gin.Context, user *model.User) serializer.Response { - // 列出分享 - shares, total := model.ListShares(user.ID, int(service.Page), 18, service.OrderBy+" "+ - service.Order, false) - // 列出分享对应的文件 - for i := 0; i < len(shares); i++ { - shares[i].Source() +type ( + ShareInfoService struct { + Password string `form:"password"` + CountViews bool `form:"count_views"` + OwnerExtended bool `form:"owner_extended"` } + ShareInfoParamCtx struct{} +) - return serializer.BuildShareList(shares, total) -} - -// Get 获取分享内容 -func (service *ShareGetService) Get(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) +func (s *ShareInfoService) Get(c *gin.Context) (*explorer.Share, error) { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + shareClient := dep.ShareClient() - // 是否已解锁 - unlocked := true - if share.Password != "" { - sessionKey := fmt.Sprintf("share_unlock_%d", share.ID) - unlocked = util.GetSession(c, sessionKey) != nil - if !unlocked && service.Password != "" { - // 如果未解锁,且指定了密码,则尝试解锁 - if service.Password == share.Password { - unlocked = true - util.SetSession(c, map[string]interface{}{sessionKey: true}) - } + ctx := context.WithValue(c, inventory.LoadShareUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadShareFile{}, true) + share, err := shareClient.GetByID(ctx, hashid.FromContext(c)) + if err != nil { + if ent.IsNotFound(err) { + return nil, serializer.NewError(serializer.CodeNotFound, "Share not found", nil) } + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get share", err) } - if unlocked { - share.Viewed() + if err := inventory.IsValidShare(share); err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "Share link expired", 
err) } - return serializer.Response{ - Code: 0, - Data: serializer.BuildShareResponse(share, unlocked), + if s.CountViews { + _ = shareClient.Viewed(c, share) } -} -// CreateDownloadSession 创建下载会话 -func (service *Service) CreateDownloadSession(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) - - // 创建文件系统 - fs, err := filesystem.NewFileSystem(user) - if err != nil { - return serializer.DBErr("Failed to update share record", err) - } - defer fs.Recycle() - - // 重设文件系统处理目标为源文件 - err = fs.SetTargetByInterface(share.Source()) - if err != nil { - return serializer.Err(serializer.CodeFileNotFound, "", err) + unlocked := true + // Share requires password + if share.Password != "" && s.Password != share.Password && share.Edges.User.ID != u.ID { + unlocked = false } - ctx := context.Background() + base := dep.SettingProvider().SiteURL(c) + res := explorer.BuildShare(share, base, dep.HashIDEncoder(), u, share.Edges.User, share.Edges.File.Name, + types.FileType(share.Edges.File.Type), unlocked) - // 重设根目录 - if share.IsDir { - fs.Root = &fs.DirTarget[0] + if s.OwnerExtended && share.Edges.User.ID == u.ID { + // Add more information about the shared file + m := manager.NewFileManager(dep, u) + defer m.Recycle() - // 找到目标文件 - err = fs.ResetFileIfNotExist(ctx, service.Path) + shareUri, err := fs.NewUriFromString(fs.NewShareUri(res.ID, s.Password)) if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) + return nil, serializer.NewError(serializer.CodeInternalSetting, "Invalid share url", err) } - } - - // 取得下载地址 - downloadURL, err := fs.GetDownloadURL(ctx, 0, "download_timeout") - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - Data: downloadURL, - } -} - -// PreviewContent 预览文件,需要登录会话, isText - 是否为文本文件,文本文件会 -// 强制经由服务端中转 -func (service *Service) PreviewContent(ctx context.Context, c *gin.Context, isText bool) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - // 用于调下层service - if share.IsDir { - ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.Source()) - ctx = context.WithValue(ctx, fsctx.PathCtx, service.Path) - } else { - ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.Source()) - } - subService := explorer.FileIDService{} - - return subService.PreviewContent(ctx, c, isText) -} - -// CreateDocPreviewSession 创建Office预览会话,返回预览地址 -func (service *Service) CreateDocPreviewSession(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - // 用于调下层service - ctx := context.Background() - if share.IsDir { - ctx = context.WithValue(ctx, fsctx.FolderModelCtx, share.Source()) - ctx = context.WithValue(ctx, fsctx.PathCtx, service.Path) - } else { - ctx = context.WithValue(ctx, fsctx.FileModelCtx, share.Source()) - } - subService := explorer.FileIDService{} - - return subService.CreateDocPreviewSession(ctx, c, false) -} - -// List 列出分享的目录下的对象 -func (service *Service) List(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - if !share.IsDir { - return serializer.ParamErr("This is not a shared folder", nil) - } - if !path.IsAbs(service.Path) { - return serializer.ParamErr("Invalid path", nil) - } - - // 创建文件系统 - fs, err := filesystem.NewFileSystem(share.Creator()) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, 
"", err) - } - defer fs.Recycle() - - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 重设根目录 - fs.Root = share.Source().(*model.Folder) - fs.Root.Name = "/" - - // 分享Key上下文 - ctx = context.WithValue(ctx, fsctx.ShareKeyCtx, hashid.HashID(share.ID, hashid.ShareID)) - - // 获取子项目 - objects, err := fs.List(ctx, service.Path, nil) - if err != nil { - return serializer.Err(serializer.CodeNotSet, err.Error(), err) - } - - return serializer.Response{ - Code: 0, - Data: serializer.BuildObjectList(0, objects, nil), - } -} - -// Thumb 获取被分享文件的缩略图 -func (service *Service) Thumb(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - if !share.IsDir { - return serializer.ParamErr("This share has no thumb", nil) - } - - // 创建文件系统 - fs, err := filesystem.NewFileSystem(share.Creator()) - if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 重设根目录 - fs.Root = share.Source().(*model.Folder) - - // 找到缩略图的父目录 - exist, parent := fs.IsPathExist(service.Path) - if !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, parent) - - // 获取文件ID - fileID, err := hashid.DecodeHashID(c.Param("file"), hashid.FileID) - if err != nil { - return serializer.Err(serializer.CodeNotFound, "", err) - } - - // 获取缩略图 - resp, err := fs.GetThumb(ctx, uint(fileID)) - if err != nil { - return serializer.Err(serializer.CodeNotSet, "Failed to get thumb", err) - } + root, err := m.Get(c, shareUri) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "File not found", err) + } - if resp.Redirect { - c.Header("Cache-Control", fmt.Sprintf("max-age=%d", resp.MaxAge)) - c.Redirect(http.StatusMovedPermanently, resp.URL) - return serializer.Response{Code: -1} + res.SourceUri = root.Uri(true).String() } - defer resp.Content.Close() - http.ServeContent(c.Writer, c.Request, "thumb.png", fs.FileTarget[0].UpdatedAt, resp.Content) - - return serializer.Response{Code: -1} + return res, nil } -// Archive 创建批量下载归档 -func (service *ArchiveService) Archive(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - userCtx, _ := c.Get("user") - user := userCtx.(*model.User) - - // 是否有权限 - if !user.Group.OptionsSerialized.ArchiveDownload { - return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) - } - - if !share.IsDir { - return serializer.ParamErr("This share cannot be batch downloaded", nil) +type ( + ListShareService struct { + PageSize int `form:"page_size" binding:"required,min=10,max=100"` + OrderBy string `uri:"order_by" form:"order_by" json:"order_by"` + OrderDirection string `uri:"order_direction" form:"order_direction" json:"order_direction"` + NextPageToken string `form:"next_page_token"` } + ListShareParamCtx struct{} +) - // 创建文件系统 - fs, err := filesystem.NewFileSystem(user) +func (s *ListShareService) List(c *gin.Context) (*ListShareResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + shareClient := dep.ShareClient() + + args := &inventory.ListShareArgs{ + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageToken: s.NextPageToken, + PageSize: s.PageSize, + Order: inventory.OrderDirection(s.OrderDirection), + OrderBy: s.OrderBy, + }, + UserID: user.ID, + } + + ctx := context.WithValue(c, inventory.LoadShareUser{}, true) + ctx = 
context.WithValue(ctx, inventory.LoadShareFile{}, true) + res, err := shareClient.List(ctx, args) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) - } - defer fs.Recycle() - - // 重设根目录 - fs.Root = share.Source().(*model.Folder) - - // 找到要打包文件的父目录 - exist, parent := fs.IsPathExist(service.Path) - if !exist { - return serializer.Err(serializer.CodeParentNotExist, "", nil) - } - - // 限制操作范围为父目录下 - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, parent) - - // 用于调下层service - tempUser := share.Creator() - tempUser.Group.OptionsSerialized.ArchiveDownload = true - c.Set("user", tempUser) - - subService := explorer.ItemIDService{ - Dirs: service.Dirs, - Items: service.Items, + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list shares", err) } - return subService.Archive(ctx, c) + base := dep.SettingProvider().SiteURL(ctx) + return BuildListShareResponse(res, hasher, base, user, true), nil } -// SearchService 对分享的目录进行搜索 -type SearchService struct { - explorer.ItemSearchService -} - -// Search 执行搜索 -func (service *SearchService) Search(c *gin.Context) serializer.Response { - shareCtx, _ := c.Get("share") - share := shareCtx.(*model.Share) - - if !share.IsDir { - return serializer.ParamErr("此分享无法列目录", nil) - } +func (s *ListShareService) ListInUserProfile(c *gin.Context, uid int) (*ListShareResponse, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) + hasher := dep.HashIDEncoder() + shareClient := dep.ShareClient() - if service.Path != "" && !path.IsAbs(service.Path) { - return serializer.ParamErr("路径无效", nil) + args := &inventory.ListShareArgs{ + PaginationArgs: &inventory.PaginationArgs{ + UseCursorPagination: true, + PageToken: s.NextPageToken, + PageSize: s.PageSize, + Order: inventory.OrderDirection(s.OrderDirection), + OrderBy: s.OrderBy, + }, + UserID: uid, + PublicOnly: true, } - // 创建文件系统 - fs, err := filesystem.NewFileSystem(share.Creator()) + ctx := context.WithValue(c, inventory.LoadShareUser{}, true) + ctx = context.WithValue(ctx, inventory.LoadShareFile{}, true) + res, err := shareClient.List(ctx, args) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list shares", err) } - defer fs.Recycle() - - // 上下文 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 重设根目录 - fs.Root = share.Source().(*model.Folder) - fs.Root.Name = "/" - if service.Path != "" { - ok, parent := fs.IsPathExist(service.Path) - if !ok { - return serializer.Err(serializer.CodeParentNotExist, "Cannot find parent folder", nil) - } - - fs.Root = parent - } - - // 分享Key上下文 - ctx = context.WithValue(ctx, fsctx.ShareKeyCtx, hashid.HashID(share.ID, hashid.ShareID)) - return service.SearchKeywords(c, fs, "%"+service.Keywords+"%") + base := dep.SettingProvider().SiteURL(ctx) + return BuildListShareResponse(res, hasher, base, user, false), nil } diff --git a/service/user/info.go b/service/user/info.go new file mode 100644 index 00000000..588beb64 --- /dev/null +++ b/service/user/info.go @@ -0,0 +1,67 @@ +package user + +import ( + "context" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs" + "github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + 
"github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/gin-gonic/gin" + "github.com/samber/lo" +) + +func GetUser(c *gin.Context) (*ent.User, error) { + uid := hashid.FromContext(c) + dep := dependency.FromContext(c) + userClient := dep.UserClient() + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + return userClient.GetByID(ctx, uid) +} + +func GetUserCapacity(c *gin.Context) (*fs.Capacity, error) { + user := inventory.UserFromContext(c) + dep := dependency.FromContext(c) + m := manager.NewFileManager(dep, user) + defer m.Recycle() + + return m.Capacity(c) +} + +type ( + SearchUserService struct { + Keyword string `form:"keyword" binding:"required,min=2"` + } + SearchUserParamCtx struct{} +) + +const resultLimit = 10 + +func (s *SearchUserService) Search(c *gin.Context) ([]*ent.User, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + res, err := userClient.SearchActive(c, resultLimit, s.Keyword) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to search user", err) + } + + return res, nil +} + +// ListAllGroups lists all groups. +func ListAllGroups(c *gin.Context) ([]*ent.Group, error) { + dep := dependency.FromContext(c) + groupClient := dep.GroupClient() + res, err := groupClient.ListAll(c) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to list all groups", err) + } + + res = lo.Filter(res, func(g *ent.Group, index int) bool { + return g.ID != inventory.AnonymousGroupID + }) + + return res, nil +} diff --git a/service/user/login.go b/service/user/login.go index 22649dde..573563c2 100644 --- a/service/user/login.go +++ b/service/user/login.go @@ -1,203 +1,245 @@ package user import ( + "context" "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/email" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" "github.com/gofrs/uuid" "github.com/pquerna/otp/totp" - "net/url" ) +// LoginParameterCtx define key fore UserLoginService +type LoginParameterCtx struct{} + // UserLoginService 管理用户登录的服务 type UserLoginService struct { - //TODO 细致调整验证规则 - UserName string `form:"userName" json:"userName" binding:"required,email"` - Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"` -} - -// UserResetEmailService 发送密码重设邮件服务 -type UserResetEmailService struct { - UserName string `form:"userName" json:"userName" binding:"required,email"` + UserName string `form:"email" json:"email" binding:"required,email"` + Password string `form:"password" json:"password" binding:"required,min=4,max=64"` } -// UserResetService 密码重设服务 -type UserResetService struct { - Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"` - ID string `json:"id" binding:"required"` - Secret string `json:"secret" 
binding:"required"` -} +type ( + // UserResetService 密码重设服务 + UserResetService struct { + Password string `form:"password" json:"password" binding:"required,min=6,max=64"` + Secret string `json:"secret" binding:"required"` + } + UserResetParameterCtx struct{} +) // Reset 重设密码 -func (service *UserResetService) Reset(c *gin.Context) serializer.Response { - // 取得原始用户ID - uid, err := hashid.DecodeHashID(service.ID, hashid.UserID) - if err != nil { - return serializer.Err(serializer.CodeInvalidTempLink, "Invalid link", err) +func (service *UserResetService) Reset(c *gin.Context) (*User, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + kv := dep.KV() + uid := hashid.FromContext(c) + + resetSession, ok := kv.Get(fmt.Sprintf("user_reset_%d", uid)) + if !ok || resetSession.(string) != service.Secret { + return nil, serializer.NewError(serializer.CodeTempLinkExpired, "Link is expired", nil) } - // 检查重设会话 - resetSession, exist := cache.Get(fmt.Sprintf("user_reset_%d", uid)) - if !exist || resetSession.(string) != service.Secret { - return serializer.Err(serializer.CodeTempLinkExpired, "Link is expired", err) + if err := kv.Delete(fmt.Sprintf("user_reset_%d", uid)); err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to delete reset session", err) } - // 重设用户密码 - user, err := model.GetActiveUserByID(uid) + u, err := userClient.GetActiveByID(c, uid) if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "User not found", nil) + return nil, serializer.NewError(serializer.CodeUserNotFound, "User not found", err) } - user.SetPassword(service.Password) - if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil { - return serializer.DBErr("Failed to reset password", err) + u, err = userClient.UpdatePassword(c, u, service.Password) + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to update password", err) } - cache.Deletes([]string{fmt.Sprintf("%d", uid)}, "user_reset_") - return serializer.Response{} + userRes := BuildUser(u, dep.HashIDEncoder()) + return &userRes, nil } +type ( + // UserResetEmailService 发送密码重设邮件服务 + UserResetEmailService struct { + UserName string `form:"email" json:"email" binding:"required,email"` + } + UserResetEmailParameterCtx struct{} +) + +const userResetPrefix = "user_reset_" + // Reset 发送密码重设邮件 -func (service *UserResetEmailService) Reset(c *gin.Context) serializer.Response { - // 查找用户 - if user, err := model.GetUserByEmail(service.UserName); err == nil { +func (service *UserResetEmailService) Reset(c *gin.Context) error { + dep := dependency.FromContext(c) + userClient := dep.UserClient() - if user.Status == model.Baned || user.Status == model.OveruseBaned { - return serializer.Err(serializer.CodeUserBaned, "This user is banned", nil) - } - if user.Status == model.NotActivicated { - return serializer.Err(serializer.CodeUserNotActivated, "This user is not activated", nil) - } - // 创建密码重设会话 - secret := util.RandStringRunes(32) - cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600) - - // 生成用户访问的重设链接 - controller, _ := url.Parse("/reset") - finalURL := model.GetSiteURL().ResolveReference(controller) - queries := finalURL.Query() - queries.Add("id", hashid.HashID(user.ID, hashid.UserID)) - queries.Add("sign", secret) - finalURL.RawQuery = queries.Encode() - - // 发送密码重设邮件 - title, body := email.NewResetEmail(user.Nick, finalURL.String()) - if err := email.Send(user.Email, title, body); err != nil { - return 
serializer.Err(serializer.CodeFailedSendEmail, "Failed to send email", err) - } + u, err := userClient.GetByEmail(c, service.UserName) + if err != nil { + return serializer.NewError(serializer.CodeUserNotFound, "User not found", err) + } + if u.Status == user.StatusManualBanned || u.Status == user.StatusSysBanned { + return serializer.NewError(serializer.CodeUserBaned, "This user is banned", nil) } - return serializer.Response{} -} + if u.Status == user.StatusInactive { + return serializer.NewError(serializer.CodeUserNotActivated, "This user is not activated", nil) + } -// Login 二步验证继续登录 -func (service *Enable2FA) Login(c *gin.Context) serializer.Response { - if uid, ok := util.GetSession(c, "2fa_user_id").(uint); ok { - // 查找用户 - expectedUser, err := model.GetActiveUserByID(uid) - if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "User not found", nil) - } + secret := util.RandStringRunes(32) + if err := dep.KV().Set(fmt.Sprintf("%s%d", userResetPrefix, u.ID), secret, 3600); err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to create reset session", err) + } - // 验证二步验证代码 - if !totp.Validate(service.Code, expectedUser.TwoFactor) { - return serializer.Err(serializer.Code2FACodeErr, "2FA code not correct", nil) - } + base := dep.SettingProvider().SiteURL(c) + resetUrl := routes.MasterUserResetUrl(base) + queries := resetUrl.Query() + queries.Add("id", hashid.EncodeUserID(dep.HashIDEncoder(), u.ID)) + queries.Add("secret", secret) + resetUrl.RawQuery = queries.Encode() - //登陆成功,清空并设置session - util.DeleteSession(c, "2fa_user_id") - util.SetSession(c, map[string]interface{}{ - "user_id": expectedUser.ID, - }) + title, body, err := email.NewResetEmail(c, dep.SettingProvider(), u, resetUrl.String()) + if err != nil { + return serializer.NewError(serializer.CodeFailedSendEmail, "Failed to send activation email", err) + } - return serializer.BuildUserResponse(expectedUser) + if err := dep.EmailClient(c).Send(c, u.Email, title, body); err != nil { + return serializer.NewError(serializer.CodeFailedSendEmail, "Failed to send activation email", err) } - return serializer.Err(serializer.CodeLoginSessionNotExist, "Login session not exist", nil) + return nil } // Login 用户登录函数 -func (service *UserLoginService) Login(c *gin.Context) serializer.Response { - expectedUser, err := model.GetUserByEmail(service.UserName) +func (service *UserLoginService) Login(c *gin.Context) (*ent.User, string, error) { + dep := dependency.FromContext(c) + userClient := dep.UserClient() + + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + expectedUser, err := userClient.GetByEmail(ctx, service.UserName) + // 一系列校验 if err != nil { - return serializer.Err(serializer.CodeCredentialInvalid, "Wrong password or email address", err) - } - if authOK, _ := expectedUser.CheckPassword(service.Password); !authOK { - return serializer.Err(serializer.CodeCredentialInvalid, "Wrong password or email address", nil) + err = serializer.NewError(serializer.CodeInvalidPassword, "Incorrect password or email address", err) + } else if checkErr := inventory.CheckPassword(expectedUser, service.Password); checkErr != nil { + err = serializer.NewError(serializer.CodeInvalidPassword, "Incorrect password or email address", err) + } else if expectedUser.Status == user.StatusManualBanned || expectedUser.Status == user.StatusSysBanned { + err = serializer.NewError(serializer.CodeUserBaned, "This account has been blocked", nil) + } else if expectedUser.Status == user.StatusInactive { + err = 
serializer.NewError(serializer.CodeUserNotActivated, "This account is not activated", nil) } - if expectedUser.Status == model.Baned || expectedUser.Status == model.OveruseBaned { - return serializer.Err(serializer.CodeUserBaned, "This account has been blocked", nil) + + if err != nil { + return nil, "", err } - if expectedUser.Status == model.NotActivicated { - return serializer.Err(serializer.CodeUserNotActivated, "This account is not activated", nil) + + if expectedUser.TwoFactorSecret != "" { + twoFaSessionID := uuid.Must(uuid.NewV4()) + dep.KV().Set(fmt.Sprintf("user_2fa_%s", twoFaSessionID), expectedUser.ID, 600) + return expectedUser, twoFaSessionID.String(), nil } - if expectedUser.TwoFactor != "" { - // 需要二步验证 - util.SetSession(c, map[string]interface{}{ - "2fa_user_id": expectedUser.ID, - }) - return serializer.Response{Code: 203} + return expectedUser, "", nil +} + +type ( + LoginLogCtx struct{} +) + +func IssueToken(c *gin.Context) (*BuiltinLoginResponse, error) { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + token, err := dep.TokenAuth().Issue(c, u) + if err != nil { + return nil, serializer.NewError(serializer.CodeEncryptError, "Failed to issue token pair", err) } - //登陆成功,清空并设置session - util.SetSession(c, map[string]interface{}{ - "user_id": expectedUser.ID, - }) + return &BuiltinLoginResponse{ + User: BuildUser(u, dep.HashIDEncoder()), + Token: *token, + }, nil +} - return serializer.BuildUserResponse(expectedUser) +// RefreshTokenParameterCtx define key fore RefreshTokenService +type RefreshTokenParameterCtx struct{} +// RefreshTokenService refresh token service +type RefreshTokenService struct { + RefreshToken string `json:"refresh_token" binding:"required"` } -// CopySessionService service for copy user session -type CopySessionService struct { - ID string `uri:"id" binding:"required,uuid4"` +func (s *RefreshTokenService) Refresh(c *gin.Context) (*auth.Token, error) { + dep := dependency.FromContext(c) + token, err := dep.TokenAuth().Refresh(c, s.RefreshToken) + if err != nil { + return nil, serializer.NewError(serializer.CodeCredentialInvalid, "Failed to issue token pair", err) + } + + return token, nil } -const CopySessionTTL = 60 +type ( + OtpValidationParameterCtx struct{} + OtpValidationService struct { + OTP string `json:"otp" binding:"required"` + SessionID string `json:"session_id" binding:"required"` + } +) + +// Login 用户登录函数 +func (service *OtpValidationService) Verify2FA(c *gin.Context) (*ent.User, error) { + dep := dependency.FromContext(c) + kv := dep.KV() -// Prepare generates the URL with short expiration duration -func (s *CopySessionService) Prepare(c *gin.Context, user *model.User) serializer.Response { - // 用户组有效期 - urlID := uuid.Must(uuid.NewV4()) - if err := cache.Set(fmt.Sprintf("copy_session_%s", urlID.String()), user.ID, CopySessionTTL); err != nil { - return serializer.Err(serializer.CodeInternalSetting, "Failed to create copy session", err) + sessionRaw, ok := kv.Get(fmt.Sprintf("user_2fa_%s", service.SessionID)) + if !ok { + return nil, serializer.NewError(serializer.CodeNotFound, "Session not found", nil) } - base := model.GetSiteURL() - apiBaseURI, _ := url.Parse("/api/v3/user/session/copy/" + urlID.String()) - apiURL := base.ResolveReference(apiBaseURI) - res, err := auth.SignURI(auth.General, apiURL.String(), CopySessionTTL) + uid := sessionRaw.(int) + ctx := context.WithValue(c, inventory.LoadUserGroup{}, true) + expectedUser, err := dep.UserClient().GetByID(ctx, uid) if err != nil { - return 
serializer.Err(serializer.CodeInternalSetting, "Failed to sign temp URL", err) + return nil, serializer.NewError(serializer.CodeNotFound, "User not found", err) } - return serializer.Response{ - Data: res.String(), + if expectedUser.TwoFactorSecret != "" { + if !totp.Validate(service.OTP, expectedUser.TwoFactorSecret) { + err := serializer.NewError(serializer.Code2FACodeErr, "Incorrect 2FA code", nil) + return nil, err + } } + + kv.Delete("user_2fa_", service.SessionID) + return expectedUser, nil } -// Copy a new session from active session, refresh max-age -func (s *CopySessionService) Copy(c *gin.Context) serializer.Response { - // 用户组有效期 - cacheKey := fmt.Sprintf("copy_session_%s", s.ID) - uid, ok := cache.Get(cacheKey) - if !ok { - return serializer.Err(serializer.CodeNotFound, "", nil) +type ( + PrepareLoginParameterCtx struct{} + PrepareLoginService struct { + Email string `form:"email" binding:"required,email"` } +) - cache.Deletes([]string{cacheKey}, "") - util.SetSession(c, map[string]interface{}{ - "user_id": uid.(uint), - }) +func (service *PrepareLoginService) Prepare(c *gin.Context) (*PrepareLoginResponse, error) { + dep := dependency.FromContext(c) + ctx := context.WithValue(c, inventory.LoadUserPasskey{}, true) + expectedUser, err := dep.UserClient().GetByEmail(ctx, service.Email) + if err != nil { + return nil, serializer.NewError(serializer.CodeNotFound, "User not found", err) + } - return serializer.Response{} + return &PrepareLoginResponse{ + WebAuthnEnabled: len(expectedUser.Edges.Passkey) > 0, + PasswordEnabled: expectedUser.Password != "", + }, nil } diff --git a/service/user/passkey.go b/service/user/passkey.go new file mode 100644 index 00000000..b2383623 --- /dev/null +++ b/service/user/passkey.go @@ -0,0 +1,288 @@ +package user + +import ( + "context" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-gonic/gin" + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" + "github.com/samber/lo" + "strconv" + "strings" +) + +func init() { + gob.Register(webauthn.SessionData{}) +} + +type authnUser struct { + hasher hashid.Encoder + u *ent.User + credentials []*ent.Passkey +} + +func (a *authnUser) WebAuthnID() []byte { + return []byte(hashid.EncodeUserID(a.hasher, a.u.ID)) +} + +func (a *authnUser) WebAuthnName() string { + return a.u.Email +} + +func (a *authnUser) WebAuthnDisplayName() string { + return a.u.Nick +} + +func (a *authnUser) WebAuthnCredentials() []webauthn.Credential { + if a.credentials == nil { + return nil + } + + return lo.Map(a.credentials, func(item *ent.Passkey, index int) webauthn.Credential { + return *item.Credential + }) +} + +const ( + authnSessionKey = "authn_session_" +) + +func PreparePasskeyLogin(c *gin.Context) (*PreparePasskeyLoginResponse, error) { + dep := dependency.FromContext(c) + webAuthn, err := dep.WebAuthn(c) + if err != nil { + return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize WebAuthn", err) + } + + options, sessionData, err := webAuthn.BeginDiscoverableLogin() + if err != nil { + return nil, serializer.NewError(serializer.CodeInitializeAuthn, "Failed to begin registration", err) + } + + sessionID := 
uuid.Must(uuid.NewV4()).String()
+    if err := dep.KV().Set(fmt.Sprintf("%s%s", authnSessionKey, sessionID), *sessionData, 300); err != nil {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to store session data", err)
+    }
+
+    return &PreparePasskeyLoginResponse{
+        Options:   options,
+        SessionID: sessionID,
+    }, nil
+}
+
+type (
+    FinishPasskeyLoginParameterCtx struct{}
+    FinishPasskeyLoginService      struct {
+        Response  string `json:"response" binding:"required"`
+        SessionID string `json:"session_id" binding:"required"`
+    }
+)
+
+func (s *FinishPasskeyLoginService) FinishPasskeyLogin(c *gin.Context) (*ent.User, error) {
+    dep := dependency.FromContext(c)
+    kv := dep.KV()
+    userClient := dep.UserClient()
+
+    sessionDataRaw, ok := kv.Get(fmt.Sprintf("%s%s", authnSessionKey, s.SessionID))
+    if !ok {
+        return nil, serializer.NewError(serializer.CodeNotFound, "Session not found", nil)
+    }
+
+    _ = kv.Delete(authnSessionKey, s.SessionID)
+
+    webAuthn, err := dep.WebAuthn(c)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize WebAuthn", err)
+    }
+
+    sessionData := sessionDataRaw.(webauthn.SessionData)
+    pcc, err := protocol.ParseCredentialRequestResponseBody(strings.NewReader(s.Response))
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeParamErr, "Failed to parse request", err)
+    }
+
+    var loginedUser *ent.User
+    discoverUserHandle := func(rawID, userHandle []byte) (user webauthn.User, err error) {
+        uid, err := dep.HashIDEncoder().Decode(string(userHandle), hashid.UserID)
+        if err != nil {
+            return nil, err
+        }
+
+        ctx := context.WithValue(c, inventory.LoadUserPasskey{}, true)
+        ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true)
+        u, err := userClient.GetLoginUserByID(ctx, uid)
+        if err != nil {
+            return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user", err)
+        }
+
+        if inventory.IsAnonymousUser(u) {
+            return nil, errors.New("anonymous user")
+        }
+
+        loginedUser = u
+        return &authnUser{u: u, hasher: dep.HashIDEncoder(), credentials: u.Edges.Passkey}, nil
+    }
+
+    credential, err := webAuthn.ValidateDiscoverableLogin(discoverUserHandle, sessionData, pcc)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeWebAuthnCredentialError, "Failed to validate login", err)
+    }
+
+    // Find the credential just used
+    usedCredentialId := base64.StdEncoding.EncodeToString(credential.ID)
+    usedCredential, found := lo.Find(loginedUser.Edges.Passkey, func(item *ent.Passkey) bool {
+        return item.CredentialID == usedCredentialId
+    })
+
+    if !found {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Passkey login passed but credential used is unknown", nil)
+    }
+
+    // Update used at
+    if err := userClient.MarkPasskeyUsed(c, loginedUser.ID, usedCredential.CredentialID); err != nil {
+        return nil, serializer.NewError(serializer.CodeDBError, "Failed to update passkey", err)
+    }
+
+    return loginedUser, nil
+}
+
+func PreparePasskeyRegister(c *gin.Context) (*protocol.CredentialCreation, error) {
+    dep := dependency.FromContext(c)
+    userClient := dep.UserClient()
+    u := inventory.UserFromContext(c)
+
+    existingKeys, err := userClient.ListPasskeys(c, u.ID)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeDBError, "Failed to list passkeys", err)
+    }
+
+    webAuthn, err := dep.WebAuthn(c)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize WebAuthn", err)
+    }
+
+    authSelect := protocol.AuthenticatorSelection{
+        RequireResidentKey: protocol.ResidentKeyRequired(),
+        UserVerification:   protocol.VerificationPreferred,
+    }
+
+    options, sessionData, err := webAuthn.BeginRegistration(
+        &authnUser{u: u, hasher: dep.HashIDEncoder()},
+        webauthn.WithAuthenticatorSelection(authSelect),
+        webauthn.WithExclusions(lo.Map(existingKeys, func(item *ent.Passkey, index int) protocol.CredentialDescriptor {
+            return protocol.CredentialDescriptor{
+                Type:            protocol.PublicKeyCredentialType,
+                CredentialID:    item.Credential.ID,
+                Transport:       item.Credential.Transport,
+                AttestationType: item.Credential.AttestationType,
+            }
+        })),
+    )
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeInitializeAuthn, "Failed to begin registration", err)
+    }
+
+    if err := dep.KV().Set(fmt.Sprintf("%s%d", authnSessionKey, u.ID), *sessionData, 300); err != nil {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to store session data", err)
+    }
+
+    return options, nil
+}
+
+type (
+    FinishPasskeyRegisterParameterCtx struct{}
+    FinishPasskeyRegisterService      struct {
+        Response string `json:"response" binding:"required"`
+        Name     string `json:"name" binding:"required"`
+        UA       string `json:"ua" binding:"required"`
+    }
+)
+
+func (s *FinishPasskeyRegisterService) FinishPasskeyRegister(c *gin.Context) (*Passkey, error) {
+    dep := dependency.FromContext(c)
+    kv := dep.KV()
+    u := inventory.UserFromContext(c)
+
+    sessionDataRaw, ok := kv.Get(fmt.Sprintf("%s%d", authnSessionKey, u.ID))
+    if !ok {
+        return nil, serializer.NewError(serializer.CodeNotFound, "Session not found", nil)
+    }
+
+    _ = kv.Delete(authnSessionKey, strconv.Itoa(u.ID))
+
+    webAuthn, err := dep.WebAuthn(c)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize WebAuthn", err)
+    }
+
+    sessionData := sessionDataRaw.(webauthn.SessionData)
+    pcc, err := protocol.ParseCredentialCreationResponseBody(strings.NewReader(s.Response))
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeParamErr, "Failed to parse request", err)
+    }
+
+    credential, err := webAuthn.CreateCredential(&authnUser{u: u, hasher: dep.HashIDEncoder()}, sessionData, pcc)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeWebAuthnCredentialError, "Failed to finish registration", err)
+    }
+
+    client := dep.UAParser().Parse(s.UA)
+    name := util.Replace(map[string]string{
+        "{os}":      client.Os.Family,
+        "{browser}": client.UserAgent.Family,
+    }, s.Name)
+
+    passkey, err := dep.UserClient().AddPasskey(c, u.ID, name, credential)
+    if err != nil {
+        return nil, serializer.NewError(serializer.CodeDBError, "Failed to add passkey", err)
+    }
+
+    res := BuildPasskey(passkey)
+    return &res, nil
+}
+
+type (
+    DeletePasskeyService struct {
+        ID string `form:"id" binding:"required"`
+    }
+    DeletePasskeyParameterCtx struct{}
+)
+
+func (s *DeletePasskeyService) DeletePasskey(c *gin.Context) error {
+    dep := dependency.FromContext(c)
+    u := inventory.UserFromContext(c)
+    userClient := dep.UserClient()
+
+    existingKeys, err := userClient.ListPasskeys(c, u.ID)
+    if err != nil {
+        return serializer.NewError(serializer.CodeDBError, "Failed to list passkeys", err)
+    }
+
+    var existing *ent.Passkey
+    for _, key := range existingKeys {
+        if key.CredentialID == s.ID {
+            existing = key
+            break
+        }
+    }
+
+    if existing == nil {
+        return serializer.NewError(serializer.CodeNotFound, "Passkey not found", nil)
+    }
+
+    if err := userClient.RemovePasskey(c, u.ID, s.ID); err != nil {
+        return serializer.NewError(serializer.CodeDBError, "Failed 
to delete passkey", err) + } + + return nil +} diff --git a/service/user/register.go b/service/user/register.go index 35e8253d..9088c7dc 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -1,113 +1,148 @@ package user import ( - "net/url" + "context" + "errors" "strings" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "time" + + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes" + "github.com/cloudreve/Cloudreve/v4/pkg/email" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/util" "github.com/gin-gonic/gin" ) +// RegisterParameterCtx define key fore UserRegisterService +type RegisterParameterCtx struct{} + // UserRegisterService 管理用户注册的服务 type UserRegisterService struct { - //TODO 细致调整验证规则 - UserName string `form:"userName" json:"userName" binding:"required,email"` - Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"` + UserName string `form:"email" json:"email" binding:"required,email"` + Password string `form:"password" json:"password" binding:"required,min=6,max=64"` + Language string `form:"language" json:"language"` } // Register 新用户注册 func (service *UserRegisterService) Register(c *gin.Context) serializer.Response { - // 相关设定 - options := model.GetSettingByNames("email_active") - - // 相关设定 - isEmailRequired := model.IsTrueVal(options["email_active"]) - defaultGroup := model.GetIntSetting("default_group", 2) - - // 创建新的用户对象 - user := model.NewUser() - user.Email = service.UserName - user.Nick = strings.Split(service.UserName, "@")[0] - user.SetPassword(service.Password) - user.Status = model.Active + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + + isEmailRequired := settings.EmailActivationEnabled(c) + args := &inventory.NewUserArgs{ + Email: strings.ToLower(service.UserName), + PlainPassword: service.Password, + Status: user.StatusActive, + GroupID: settings.DefaultGroup(c), + Language: service.Language, + } if isEmailRequired { - user.Status = model.NotActivicated + args.Status = user.StatusInactive } - user.GroupID = uint(defaultGroup) - userNotActivated := false - // 创建用户 - if err := model.DB.Create(&user).Error; err != nil { - //检查已存在使用者是否尚未激活 - expectedUser, err := model.GetUserByEmail(service.UserName) - if expectedUser.Status == model.NotActivicated { - userNotActivated = true - user = expectedUser - } else { - return serializer.Err(serializer.CodeEmailExisted, "Email already in use", err) - } + + userClient := dep.UserClient() + uc, tx, _, err := inventory.WithTx(c, userClient) + if err != nil { + return serializer.DBErr(c, "Failed to start transaction", err) } - // 发送激活邮件 - if isEmailRequired { + expectedUser, err := uc.Create(c, args) + if expectedUser != nil { + util.WithValue(c, inventory.UserCtx{}, expectedUser) + } - // 签名激活请求API - base := model.GetSiteURL() - userID := hashid.HashID(user.ID, hashid.UserID) - controller, _ := url.Parse("/api/v3/user/activate/" + userID) - activateURL, err := auth.SignURI(auth.General, 
base.ResolveReference(controller).String(), 86400) - if err != nil { - return serializer.Err(serializer.CodeEncryptError, "Failed to sign the activation link", err) + if err != nil { + _ = inventory.Rollback(tx) + if errors.Is(err, inventory.ErrUserEmailExisted) { + return serializer.ErrWithDetails(c, serializer.CodeEmailExisted, "Email already in use", err) } - // 取得签名 - credential := activateURL.Query().Get("sign") - - // 生成对用户访问的激活地址 - controller, _ = url.Parse("/activate") - finalURL := base.ResolveReference(controller) - queries := finalURL.Query() - queries.Add("id", userID) - queries.Add("sign", credential) - finalURL.RawQuery = queries.Encode() - - // 返送激活邮件 - title, body := email.NewActivationEmail(user.Email, - finalURL.String(), - ) - if err := email.Send(user.Email, title, body); err != nil { - return serializer.Err(serializer.CodeFailedSendEmail, "Failed to send activation email", err) + if errors.Is(err, inventory.ErrInactiveUserExisted) { + if err := sendActivationEmail(c, dep, expectedUser); err != nil { + return serializer.ErrWithDetails(c, serializer.CodeNotSet, "", err) + } + + return serializer.ErrWithDetails(c, serializer.CodeEmailSent, "User is not activated, activation email has been resent", nil) } - if userNotActivated == true { - //原本在上面要抛出的DBErr,放来这边抛出 - return serializer.Err(serializer.CodeEmailSent, "User is not activated, activation email has been resent", nil) - } else { - return serializer.Response{Code: 203} + + return serializer.DBErr(c, "Failed to insert user row", err) + } + + if err := inventory.Commit(tx); err != nil { + return serializer.DBErr(c, "Failed to commit user row", err) + } + + if isEmailRequired { + if err := sendActivationEmail(c, dep, expectedUser); err != nil { + return serializer.ErrWithDetails(c, serializer.CodeNotSet, "", err) } + return serializer.Response{Code: serializer.CodeNotFullySuccess} } - return serializer.Response{} + return serializer.Response{Data: BuildUser(expectedUser, dep.HashIDEncoder())} } -// Activate 激活用户 -func (service *SettingService) Activate(c *gin.Context) serializer.Response { +func sendActivationEmail(ctx context.Context, dep dependency.Dep, newUser *ent.User) error { + base := dep.SettingProvider().SiteURL(ctx) + userID := hashid.EncodeUserID(dep.HashIDEncoder(), newUser.ID) + ttl := time.Now().Add(time.Duration(24) * time.Hour) + activateURL, err := auth.SignURI(ctx, dep.GeneralAuth(), routes.MasterUserActivateAPIUrl(base, userID).String(), &ttl) + if err != nil { + return serializer.NewError(serializer.CodeEncryptError, "Failed to sign the activation link", err) + } + + // 取得签名 + credential := activateURL.Query().Get("sign") + + // 生成对用户访问的激活地址 + finalURL := routes.MasterUserActivateUrl(base) + queries := finalURL.Query() + queries.Add("id", userID) + queries.Add("sign", credential) + finalURL.RawQuery = queries.Encode() + + // 返送激活邮件 + title, body, err := email.NewActivationEmail(ctx, dep.SettingProvider(), newUser, finalURL.String()) + if err != nil { + return serializer.NewError(serializer.CodeFailedSendEmail, "Failed to send activation email", err) + } + + if err := dep.EmailClient(ctx).Send(ctx, newUser.Email, title, body); err != nil { + return serializer.NewError(serializer.CodeFailedSendEmail, "Failed to send activation email", err) + } + + return nil +} + +// ActivateUser 激活用户 +func ActivateUser(c *gin.Context) serializer.Response { + uid := hashid.FromContext(c) + dep := dependency.FromContext(c) + userClient := dep.UserClient() + // 查找待激活用户 - uid, _ := c.Get("object_id") - user, err := 
model.GetUserByID(uid.(uint)) + inactiveUser, err := userClient.GetByID(c, uid) if err != nil { - return serializer.Err(serializer.CodeUserNotFound, "User not fount", err) + return serializer.ErrWithDetails(c, serializer.CodeUserNotFound, "User not fount", err) } // 检查状态 - if user.Status != model.NotActivicated { - return serializer.Err(serializer.CodeUserCannotActivate, "This user cannot be activated", nil) + if inactiveUser.Status != user.StatusInactive { + return serializer.ErrWithDetails(c, serializer.CodeUserCannotActivate, "This user cannot be activated", nil) } // 激活用户 - user.SetStatus(model.Active) + activeUser, err := userClient.SetStatus(c, inactiveUser, user.StatusActive) + if err != nil { + return serializer.DBErr(c, "Failed to update user", err) + } - return serializer.Response{Data: user.Email} + util.WithValue(c, inventory.UserCtx{}, activeUser) + return serializer.Response{Data: BuildUser(activeUser, dep.HashIDEncoder())} } diff --git a/service/user/response.go b/service/user/response.go new file mode 100644 index 00000000..d4ceda39 --- /dev/null +++ b/service/user/response.go @@ -0,0 +1,219 @@ +package user + +import ( + "fmt" + "time" + + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/ent/user" + "github.com/cloudreve/Cloudreve/v4/inventory/types" + "github.com/cloudreve/Cloudreve/v4/pkg/auth" + "github.com/cloudreve/Cloudreve/v4/pkg/boolset" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/samber/lo" + "github.com/ua-parser/uap-go/uaparser" +) + +type PreparePasskeyLoginResponse struct { + Options *protocol.CredentialAssertion `json:"options"` + SessionID string `json:"session_id"` +} + +type UserSettings struct { + VersionRetentionEnabled bool `json:"version_retention_enabled"` + VersionRetentionExt []string `json:"version_retention_ext,omitempty"` + VersionRetentionMax int `json:"version_retention_max,omitempty"` + Paswordless bool `json:"passwordless"` + TwoFAEnabled bool `json:"two_fa_enabled"` + Passkeys []Passkey `json:"passkeys,omitempty"` +} + +func BuildUserSettings(u *ent.User, passkeys []*ent.Passkey, parser *uaparser.Parser) *UserSettings { + return &UserSettings{ + VersionRetentionEnabled: u.Settings.VersionRetention, + VersionRetentionExt: u.Settings.VersionRetentionExt, + VersionRetentionMax: u.Settings.VersionRetentionMax, + TwoFAEnabled: u.TwoFactorSecret != "", + Paswordless: u.Password == "", + Passkeys: lo.Map(passkeys, func(item *ent.Passkey, index int) Passkey { + return BuildPasskey(item) + }), + } +} + +type Passkey struct { + ID string `json:"id"` + Name string `json:"name"` + UsedAt *time.Time `json:"used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +func BuildPasskey(passkey *ent.Passkey) Passkey { + return Passkey{ + ID: passkey.CredentialID, + Name: passkey.Name, + UsedAt: passkey.UsedAt, + CreatedAt: passkey.CreatedAt, + } +} + +// Node option for handling workflows. +type Node struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Capabilities *boolset.BooleanSet `json:"capabilities"` +} + +// BuildNodes serialize a list of nodes. +func BuildNodes(nodes []*ent.Node, idEncoder hashid.Encoder) []*Node { + res := make([]*Node, 0, len(nodes)) + for _, v := range nodes { + res = append(res, BuildNode(v, idEncoder)) + } + + return res +} + +// BuildNode serialize a node. 
+func BuildNode(node *ent.Node, idEncoder hashid.Encoder) *Node { + return &Node{ + ID: hashid.EncodeNodeID(idEncoder, node.ID), + Name: node.Name, + Type: string(node.Type), + Capabilities: node.Capabilities, + } +} + +// BuiltinLoginResponse response for a successful login for builtin auth provider. +type BuiltinLoginResponse struct { + User User `json:"user"` + Token auth.Token `json:"token"` +} + +// User 用户序列化器 +type User struct { + ID string `json:"id"` + Email string `json:"email,omitempty"` + Nickname string `json:"nickname"` + Status user.Status `json:"status,omitempty"` + Avatar string `json:"avatar,omitempty"` + CreatedAt time.Time `json:"created_at"` + PreferredTheme string `json:"preferred_theme,omitempty"` + Anonymous bool `json:"anonymous,omitempty"` + Group *Group `json:"group,omitempty"` + Pined []types.PinedFile `json:"pined,omitempty"` + Language string `json:"language,omitempty"` +} + +type Group struct { + ID string `json:"id"` + Name string `json:"name"` + Permission *boolset.BooleanSet `json:"permission,omitempty"` + DirectLinkBatchSize int `json:"direct_link_batch_size,omitempty"` + TrashRetention int `json:"trash_retention,omitempty"` +} + +type storage struct { + Used uint64 `json:"used"` + Free uint64 `json:"free"` + Total uint64 `json:"total"` +} + +// WebAuthnCredentials 外部验证器凭证 +type WebAuthnCredentials struct { + ID []byte `json:"id"` + FingerPrint string `json:"fingerprint"` +} + +type PrepareLoginResponse struct { + WebAuthnEnabled bool `json:"webauthn_enabled"` + PasswordEnabled bool `json:"password_enabled"` +} + +// BuildWebAuthnList 构建设置页面凭证列表 +func BuildWebAuthnList(credentials []webauthn.Credential) []WebAuthnCredentials { + res := make([]WebAuthnCredentials, 0, len(credentials)) + for _, v := range credentials { + credential := WebAuthnCredentials{ + ID: v.ID, + FingerPrint: fmt.Sprintf("% X", v.Authenticator.AAGUID), + } + res = append(res, credential) + } + + return res +} + +// BuildUser 序列化用户 +func BuildUser(user *ent.User, idEncoder hashid.Encoder) User { + return User{ + ID: hashid.EncodeUserID(idEncoder, user.ID), + Email: user.Email, + Nickname: user.Nick, + Status: user.Status, + Avatar: user.Avatar, + CreatedAt: user.CreatedAt, + PreferredTheme: user.Settings.PreferredTheme, + Anonymous: user.ID == 0, + Group: BuildGroup(user.Edges.Group, idEncoder), + Pined: user.Settings.Pined, + Language: user.Settings.Language, + } +} + +func BuildGroup(group *ent.Group, idEncoder hashid.Encoder) *Group { + if group == nil { + return nil + } + return &Group{ + ID: hashid.EncodeGroupID(idEncoder, group.ID), + Name: group.Name, + Permission: group.Permissions, + DirectLinkBatchSize: group.Settings.SourceBatchSize, + TrashRetention: group.Settings.TrashRetention, + } +} + +const sensitiveTag = "redacted" + +const ( + RedactLevelAnonymous = iota + RedactLevelUser +) + +// BuildUserRedacted Serialize a user without sensitive information. +func BuildUserRedacted(u *ent.User, level int, idEncoder hashid.Encoder) User { + userRaw := BuildUser(u, idEncoder) + + user := User{ + ID: userRaw.ID, + Nickname: userRaw.Nickname, + Avatar: userRaw.Avatar, + CreatedAt: userRaw.CreatedAt, + } + + if userRaw.Group != nil { + user.Group = RedactedGroup(userRaw.Group) + } + + if level == RedactLevelUser { + user.Email = userRaw.Email + } + + return user +} + +// BuildGroupRedacted Serialize a group without sensitive information. 
+func RedactedGroup(g *Group) *Group { + if g == nil { + return nil + } + + return &Group{ + ID: g.ID, + Name: g.Name, + } +} diff --git a/service/user/setting.go b/service/user/setting.go index 8d7f6191..a6dd7b66 100644 --- a/service/user/setting.go +++ b/service/user/setting.go @@ -1,256 +1,308 @@ package user import ( + "context" "crypto/md5" "fmt" + "github.com/cloudreve/Cloudreve/v4/application/dependency" + "github.com/cloudreve/Cloudreve/v4/ent" + "github.com/cloudreve/Cloudreve/v4/inventory" + "github.com/cloudreve/Cloudreve/v4/pkg/hashid" + "github.com/cloudreve/Cloudreve/v4/pkg/request" + "github.com/cloudreve/Cloudreve/v4/pkg/serializer" + "github.com/cloudreve/Cloudreve/v4/pkg/setting" + "github.com/cloudreve/Cloudreve/v4/pkg/thumb" + "github.com/cloudreve/Cloudreve/v4/pkg/util" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" + "io" "net/http" "net/url" "os" "path/filepath" "strings" +) - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" +const ( + twoFaEnableSessionKey = "2fa_init_" ) -// SettingService 通用设置服务 -type SettingService struct { -} +// Init2FA 初始化二步验证 +func Init2FA(c *gin.Context) (string, error) { + dep := dependency.FromContext(c) + user := inventory.UserFromContext(c) -// SettingListService 通用设置列表服务 -type SettingListService struct { - Page int `form:"page" binding:"required,min=1"` -} + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "Cloudreve", + AccountName: user.Email, + }) + if err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to generate TOTP secret", err) + } -// AvatarService 头像服务 -type AvatarService struct { - Size string `uri:"size" binding:"required,eq=l|eq=m|eq=s"` -} + if err := dep.KV().Set(fmt.Sprintf("%s%d", twoFaEnableSessionKey, user.ID), key.Secret(), 600); err != nil { + return "", serializer.NewError(serializer.CodeInternalSetting, "Failed to store TOTP session", err) + } -// SettingUpdateService 设定更改服务 -type SettingUpdateService struct { - Option string `uri:"option" binding:"required,eq=nick|eq=theme|eq=homepage|eq=vip|eq=qq|eq=policy|eq=password|eq=2fa|eq=authn"` + return key.Secret(), nil } -// OptionsChangeHandler 属性更改接口 -type OptionsChangeHandler interface { - Update(*gin.Context, *model.User) serializer.Response -} +type ( + // AvatarService Service to get avatar + GetAvatarService struct { + NoCache bool `form:"nocache"` + } + GetAvatarServiceParamsCtx struct{} +) -// ChangerNick 昵称更改服务 -type ChangerNick struct { - Nick string `json:"nick" binding:"required,min=1,max=255"` -} +const ( + GravatarAvatar = "gravatar" + FileAvatar = "file" +) -// PolicyChange 更改存储策略 -type PolicyChange struct { - ID string `json:"id" binding:"required"` -} +// Get 获取用户头像 +func (service *GetAvatarService) Get(c *gin.Context) error { + dep := dependency.FromContext(c) + settings := dep.SettingProvider() + // 查找目标用户 + uid := hashid.FromContext(c) + userClient := dep.UserClient() + user, err := userClient.GetByID(c, uid) -// HomePage 更改个人主页开关 -type HomePage struct { - Enabled bool `json:"status"` -} + if err != nil { + return serializer.NewError(serializer.CodeUserNotFound, "", err) + } -// PasswordChange 更改密码 -type PasswordChange struct { - Old string `json:"old" binding:"required,min=4,max=64"` - New string `json:"new" binding:"required,min=4,max=64"` -} + if !service.NoCache { + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", 
settings.PublicResourceMaxAge(c))) + } -// Enable2FA 开启二步验证 -type Enable2FA struct { - Code string `json:"code" binding:"required"` -} + // 未设定头像时,返回404错误 + if user.Avatar == "" { + c.Status(404) + return nil + } -// DeleteWebAuthn 删除WebAuthn凭证 -type DeleteWebAuthn struct { - ID string `json:"id" binding:"required"` -} + avatarSettings := settings.Avatar(c) -// ThemeChose 主题选择 -type ThemeChose struct { - Theme string `json:"theme" binding:"required,hexcolor|rgb|rgba|hsl"` -} + // Gravatar 头像重定向 + if user.Avatar == GravatarAvatar { + gravatarRoot, err := url.Parse(avatarSettings.Gravatar) + if err != nil { + return serializer.NewError(serializer.CodeInternalSetting, "Failed to parse Gravatar server", err) + } + email_lowered := strings.ToLower(user.Email) + has := md5.Sum([]byte(email_lowered)) + avatar, _ := url.Parse(fmt.Sprintf("/avatar/%x?d=mm&s=200", has)) -// Update 更新主题设定 -func (service *ThemeChose) Update(c *gin.Context, user *model.User) serializer.Response { - user.OptionsSerialized.PreferredTheme = service.Theme - if err := user.UpdateOptions(); err != nil { - return serializer.DBErr("Failed to update user preferences", err) + c.Redirect(http.StatusFound, gravatarRoot.ResolveReference(avatar).String()) + return nil } - return serializer.Response{} -} + // 本地文件头像 + if user.Avatar == FileAvatar { + avatarRoot := util.DataPath(avatarSettings.Path) -// Update 删除凭证 -func (service *DeleteWebAuthn) Update(c *gin.Context, user *model.User) serializer.Response { - user.RemoveAuthn(service.ID) - return serializer.Response{} + avatar, err := os.Open(filepath.Join(avatarRoot, fmt.Sprintf("avatar_%d.png", user.ID))) + if err != nil { + dep.Logger().Warning("Failed to open avatar file", err) + c.Status(404) + } + defer avatar.Close() + + http.ServeContent(c.Writer, c.Request, "avatar.png", user.UpdatedAt, avatar) + return nil + } + + c.Status(404) + return nil } -// Update 更改二步验证设定 -func (service *Enable2FA) Update(c *gin.Context, user *model.User) serializer.Response { - if user.TwoFactor == "" { - // 开启2FA - secret, ok := util.GetSession(c, "2fa_init").(string) - if !ok { - return serializer.Err(serializer.CodeInternalSetting, "You have not initiated 2FA session", nil) - } +// Settings 获取用户设定 +func GetUserSettings(c *gin.Context) (*UserSettings, error) { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + userClient := dep.UserClient() + passkeys, err := userClient.ListPasskeys(c, u.ID) + if err != nil { + return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user passkey", err) + } - if !totp.Validate(service.Code, secret) { - return serializer.ParamErr("Incorrect 2FA code", nil) - } + return BuildUserSettings(u, passkeys, dep.UAParser()), nil + + // 用户组有效期 + + //return serializer.Response{ + // Data: map[string]interface{}{ + // "uid": user.ID, + // "qq": user.OpenID != "", + // "homepage": !user.OptionsSerialized.ProfileOff, + // "two_factor": user.TwoFactor != "", + // "prefer_theme": user.OptionsSerialized.PreferredTheme, + // "themes": model.GetSettingByName("themes"), + // "group_expires": groupExpires, + // "authn": serializer.BuildWebAuthnList(user.WebAuthnCredentials()), + // }, + //} +} - if err := user.Update(map[string]interface{}{"two_factor": secret}); err != nil { - return serializer.DBErr("Failed to update user preferences", err) - } +func UpdateUserAvatar(c *gin.Context) error { + dep := dependency.FromContext(c) + u := inventory.UserFromContext(c) + settings := dep.SettingProvider() - } else { - // 关闭2FA - if 
-			return serializer.ParamErr("Incorrect 2FA code", nil)
-		}
+	avatarSettings := settings.AvatarProcess(c)
+	if c.Request.ContentLength == -1 || c.Request.ContentLength > avatarSettings.MaxFileSize {
+		request.BlackHole(c.Request.Body)
+		return serializer.NewError(serializer.CodeFileTooLarge, "", nil)
+	}
-		if err := user.Update(map[string]interface{}{"two_factor": ""}); err != nil {
-			return serializer.DBErr("Failed to update user preferences", err)
+	if c.Request.ContentLength == 0 {
+		// Use Gravatar for empty body
+		if _, err := dep.UserClient().UpdateAvatar(c, u, GravatarAvatar); err != nil {
+			return serializer.NewError(serializer.CodeDBError, "Failed to update user avatar", err)
 		}
+
+		return nil
 	}
-	return serializer.Response{}
+	return updateAvatarFile(c, u, c.GetHeader("Content-Type"), c.Request.Body, avatarSettings)
 }
-// Init2FA initializes two-factor authentication
-func (service *SettingService) Init2FA(c *gin.Context, user *model.User) serializer.Response {
-	key, err := totp.Generate(totp.GenerateOpts{
-		Issuer:      "Cloudreve",
-		AccountName: user.Email,
-	})
+func updateAvatarFile(ctx context.Context, u *ent.User, contentType string, file io.Reader, avatarSettings *setting.AvatarProcess) error {
+	dep := dependency.FromContext(ctx)
+	// Detect ext from content type
+	ext := "png"
+	switch contentType {
+	case "image/jpeg", "image/jpg":
+		ext = "jpg"
+	case "image/gif":
+		ext = "gif"
+	}
+	avatar, err := thumb.NewThumbFromFile(file, ext)
 	if err != nil {
-		return serializer.Err(serializer.CodeInternalSetting, "Failed to generate TOTP secret", err)
+		return serializer.NewError(serializer.CodeParamErr, "Invalid image", err)
 	}
-	util.SetSession(c, map[string]interface{}{"2fa_init": key.Secret()})
-	return serializer.Response{Data: key.Secret()}
-}
+	// Resize and save avatar
+	avatar.CreateAvatar(avatarSettings.MaxWidth)
+	avatarRoot := util.DataPath(avatarSettings.Path)
+	f, err := util.CreatNestedFile(filepath.Join(avatarRoot, fmt.Sprintf("avatar_%d.png", u.ID)))
+	if err != nil {
+		return serializer.NewError(serializer.CodeIOFailed, "Failed to create avatar file", err)
+	}
-// Update changes the password
-func (service *PasswordChange) Update(c *gin.Context, user *model.User) serializer.Response {
-	// Verify the old password
-	if ok, _ := user.CheckPassword(service.Old); !ok {
-		return serializer.Err(serializer.CodeIncorrectPassword, "", nil)
+	defer f.Close()
+	if err := avatar.Save(f, &setting.ThumbEncode{
+		Quality: 100,
+		Format:  "png",
+	}); err != nil {
+		return serializer.NewError(serializer.CodeIOFailed, "Failed to save avatar file", err)
 	}
-	// Set the new password
-	user.SetPassword(service.New)
-	if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
-		return serializer.DBErr("Failed to update password", err)
+	if _, err := dep.UserClient().UpdateAvatar(ctx, u, FileAvatar); err != nil {
+		return serializer.NewError(serializer.CodeDBError, "Failed to update user avatar", err)
 	}
-	return serializer.Response{}
+	return nil
 }
-// Update toggles the profile homepage
-func (service *HomePage) Update(c *gin.Context, user *model.User) serializer.Response {
-	user.OptionsSerialized.ProfileOff = !service.Enabled
-	if err := user.UpdateOptions(); err != nil {
-		return serializer.DBErr("Failed to update user preferences", err)
+type (
+	PatchUserSetting struct {
+		Nick                    *string   `json:"nick" binding:"omitempty,min=1,max=255"`
+		Language                *string   `json:"language" binding:"omitempty,min=1,max=255"`
+		PreferredTheme          *string   `json:"preferred_theme" binding:"omitempty,hexcolor|rgb|rgba|hsl"`
+		VersionRetentionEnabled *bool     `json:"version_retention_enabled" binding:"omitempty"`
+		VersionRetentionExt     *[]string `json:"version_retention_ext" binding:"omitempty"`
+		VersionRetentionMax     *int      `json:"version_retention_max" binding:"omitempty,min=0"`
+		CurrentPassword         *string   `json:"current_password" binding:"omitempty,min=4,max=64"`
+		NewPassword             *string   `json:"new_password" binding:"omitempty,min=6,max=64"`
+		TwoFAEnabled            *bool     `json:"two_fa_enabled" binding:"omitempty"`
+		TwoFACode               *string   `json:"two_fa_code" binding:"omitempty"`
 	}
+	PatchUserSettingParamsCtx struct{}
+)
-	return serializer.Response{}
-}
+func (s *PatchUserSetting) Patch(c *gin.Context) error {
+	dep := dependency.FromContext(c)
+	u := inventory.UserFromContext(c)
+	userClient := dep.UserClient()
+	saveSetting := false
-// Update changes the nickname
-func (service *ChangerNick) Update(c *gin.Context, user *model.User) serializer.Response {
-	if err := user.Update(map[string]interface{}{"nick": service.Nick}); err != nil {
-		return serializer.DBErr("Failed to update user", err)
+	if s.Nick != nil {
+		if _, err := userClient.UpdateNickname(c, u, *s.Nick); err != nil {
+			return serializer.NewError(serializer.CodeDBError, "Failed to update user nick", err)
+		}
 	}
-	return serializer.Response{}
-}
+	if s.Language != nil {
+		u.Settings.Language = *s.Language
+		saveSetting = true
+	}
-// Get retrieves the avatar of a user
-func (service *AvatarService) Get(c *gin.Context) serializer.Response {
-	// Find the target user
-	uid, _ := c.Get("object_id")
-	user, err := model.GetActiveUserByID(uid.(uint))
-	if err != nil {
-		return serializer.Err(serializer.CodeUserNotFound, "", err)
+	if s.PreferredTheme != nil {
+		u.Settings.PreferredTheme = *s.PreferredTheme
+		saveSetting = true
 	}
-	// Return a 404 error when no avatar is set
-	if user.Avatar == "" {
-		c.Status(404)
-		return serializer.Response{}
+	if s.VersionRetentionEnabled != nil {
+		u.Settings.VersionRetention = *s.VersionRetentionEnabled
+		saveSetting = true
 	}
-	// Get avatar settings
-	sizes := map[string]string{
-		"s": model.GetSettingByName("avatar_size_s"),
-		"m": model.GetSettingByName("avatar_size_m"),
-		"l": model.GetSettingByName("avatar_size_l"),
+	if s.VersionRetentionExt != nil {
+		u.Settings.VersionRetentionExt = *s.VersionRetentionExt
+		saveSetting = true
 	}
-	// Redirect to Gravatar
-	if user.Avatar == "gravatar" {
-		server := model.GetSettingByName("gravatar_server")
-		gravatarRoot, err := url.Parse(server)
-		if err != nil {
-			return serializer.Err(serializer.CodeInternalSetting, "Failed to parse Gravatar server", err)
+	if s.VersionRetentionMax != nil {
+		u.Settings.VersionRetentionMax = *s.VersionRetentionMax
+		saveSetting = true
+	}
+
+	if s.CurrentPassword != nil && s.NewPassword != nil {
+		if err := inventory.CheckPassword(u, *s.CurrentPassword); err != nil {
+			return serializer.NewError(serializer.CodeIncorrectPassword, "Incorrect password", err)
 		}
-		email_lowered := strings.ToLower(user.Email)
-		has := md5.Sum([]byte(email_lowered))
-		avatar, _ := url.Parse(fmt.Sprintf("/avatar/%x?d=mm&s=%s", has, sizes[service.Size]))
-		return serializer.Response{
-			Code: -301,
-			Data: gravatarRoot.ResolveReference(avatar).String(),
+		if _, err := userClient.UpdatePassword(c, u, *s.NewPassword); err != nil {
+			return serializer.NewError(serializer.CodeDBError, "Failed to update user password", err)
 		}
 	}
-	// Local file avatar
-	if user.Avatar == "file" {
-		avatarRoot := util.RelativePath(model.GetSettingByName("avatar_path"))
-		sizeToInt := map[string]string{
-			"s": "0",
-			"m": "1",
-			"l": "2",
-		}
+	if s.TwoFAEnabled != nil {
+		// Both branches below dereference s.TwoFACode, so a code must be supplied.
+		if s.TwoFACode == nil {
+			return serializer.NewError(serializer.CodeParamErr, "2FA code required", nil)
+		}
+
+		if *s.TwoFAEnabled {
+			kv := dep.KV()
+			secret, ok := kv.Get(fmt.Sprintf("%s%d", twoFaEnableSessionKey, u.ID))
+			if !ok {
+				return serializer.NewError(serializer.CodeInternalSetting, "You have not initiated 2FA session", nil)
+			}
-		avatar, err := os.Open(filepath.Join(avatarRoot, fmt.Sprintf("avatar_%d_%s.png", user.ID, sizeToInt[service.Size])))
-		if err != nil {
-			c.Status(404)
-			return serializer.Response{}
-		}
-		defer avatar.Close()
+			if !totp.Validate(*s.TwoFACode, secret.(string)) {
+				return serializer.NewError(serializer.Code2FACodeErr, "Incorrect 2FA code", nil)
+			}
-		http.ServeContent(c.Writer, c.Request, "avatar.png", user.UpdatedAt, avatar)
-		return serializer.Response{}
-	}
+			if _, err := userClient.UpdateTwoFASecret(c, u, secret.(string)); err != nil {
+				return serializer.NewError(serializer.CodeDBError, "Failed to update user 2FA", err)
+			}
-	c.Status(404)
-	return serializer.Response{}
-}
+		} else {
+			if !totp.Validate(*s.TwoFACode, u.TwoFactorSecret) {
+				return serializer.NewError(serializer.Code2FACodeErr, "Incorrect 2FA code", nil)
+			}
-// ListTasks lists the user's tasks
-func (service *SettingListService) ListTasks(c *gin.Context, user *model.User) serializer.Response {
-	tasks, total := model.ListTasks(user.ID, service.Page, 10, "updated_at desc")
-	return serializer.BuildTaskList(tasks, total)
-}
+			if _, err := userClient.UpdateTwoFASecret(c, u, ""); err != nil {
+				return serializer.NewError(serializer.CodeDBError, "Failed to update user 2FA", err)
+			}
-// Settings gets the user's settings
-func (service *SettingService) Settings(c *gin.Context, user *model.User) serializer.Response {
-	return serializer.Response{
-		Data: map[string]interface{}{
-			"uid":          user.ID,
-			"homepage":     !user.OptionsSerialized.ProfileOff,
-			"two_factor":   user.TwoFactor != "",
-			"prefer_theme": user.OptionsSerialized.PreferredTheme,
-			"themes":       model.GetSettingByName("themes"),
-			"authn":        serializer.BuildWebAuthnList(user.WebAuthnCredentials()),
-		},
+		}
+	}
+
+	if saveSetting {
+		if err := userClient.SaveSettings(c, u); err != nil {
+			return serializer.NewError(serializer.CodeDBError, "Failed to update user settings", err)
+		}
 	}
+
+	return nil
 }