diff --git a/main.go b/main.go index 9245afc..3f039e4 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/authn" + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/routers" "github.com/gin-gonic/gin" @@ -11,6 +12,7 @@ import ( func init() { conf.Init("conf/conf.ini") + cache.Init() model.Init() // Debug 关闭时,切换为生产模式 diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 04cbeba..7bd7868 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -8,7 +8,8 @@ import ( // Store 缓存存储器 var Store Driver -func init() { +// Init 初始化缓存 +func Init() { //Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") //return if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode { diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go index 677c5ea..864a487 100644 --- a/pkg/cache/driver_test.go +++ b/pkg/cache/driver_test.go @@ -42,3 +42,11 @@ func TestSetSettings(t *testing.T) { asserts.Equal("3", value1) asserts.Equal("4", value2) } + +func TestInit(t *testing.T) { + asserts := assert.New(t) + + asserts.NotPanics(func() { + Init() + }) +} diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 43a32f9..2a3582f 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -3,6 +3,7 @@ package filesystem import ( "context" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "github.com/juju/ratelimit" "io" @@ -150,3 +151,26 @@ func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) return policyGroup } + +// GetSource 获取可直接访问文件的外链地址 +func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) { + // 查找文件记录 + fileObject, err := model.GetFilesByIDs([]uint{fileID}, fs.User.ID) + if err != nil || len(fileObject) == 0 { + return "", ErrObjectNotExist.WithError(err) + } + + fs.FileTarget = []model.File{fileObject[0]} + // 将当前存储策略重设为文件使用的 + fs.Policy = fileObject[0].GetPolicy() + err = fs.dispatchHandler() + if err != nil { + return "", err + } + + // 检查存储策略是否可以获得外链 + if !fs.Policy.IsOriginLinkEnable { + return "", serializer.NewError(serializer.CodePolicyNotAllowed, "当前存储策略无法获得外链", nil) + } + return "", nil +} diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 4acbf9b..4d86913 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -15,6 +15,40 @@ import ( "strconv" ) +// GetSource 获取文件的外链地址 +func GetSource(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fs, err := filesystem.NewFileSystemFromContext(c) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) + return + } + + // 获取文件ID + fileID, err := strconv.ParseUint(c.Param("id"), 10, 32) + if err != nil { + c.JSON(200, serializer.ParamErr("无法解析文件ID", err)) + return + } + + sourceURL, err := fs.GetSource(ctx, uint(fileID)) + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeNotSet, err.Error(), err)) + return + } + + c.JSON(200, serializer.Response{ + Code: 0, + Data: struct { + URL string `json:"url"` + }{URL: sourceURL}, + }) + +} + // Thumb 获取文件缩略图 func Thumb(c *gin.Context) { // 创建上下文 diff --git a/routers/router.go b/routers/router.go index b384666..4ec7d1e 100644 --- a/routers/router.go +++ b/routers/router.go @@ -41,12 +41,17 @@ func InitRouter() *gin.Engine { { // 测试用路由 v3.GET("site/ping", controllers.Ping) - // 用户登录 - v3.POST("user/session", controllers.UserLogin) - // WebAuthn登陆初始化 - v3.GET("user/authn/:username", controllers.StartLoginAuthn) - // WebAuthn登陆 - v3.POST("user/authn/finish/:username", controllers.FinishLoginAuthn) + + // 不需要登录的用户相关路由 + { + // 用户登录 + v3.POST("user/session", controllers.UserLogin) + // WebAuthn登陆初始化 + v3.GET("user/authn/:username", controllers.StartLoginAuthn) + // WebAuthn登陆 + v3.POST("user/authn/finish/:username", controllers.FinishLoginAuthn) + } + // 验证码 v3.GET("captcha", controllers.Captcha) // 站点全局配置 @@ -80,6 +85,8 @@ func InitRouter() *gin.Engine { file.GET("download/*path", controllers.Download) // 下载文件 file.GET("thumb/:id", controllers.Thumb) + // 取得文件外链 + file.GET("source/:id", controllers.GetSource) } // 目录