From 02c93be3bc1a70c53300f0669f4974afc4652907 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Fri, 13 Dec 2019 15:10:44 +0800 Subject: [PATCH] Feat: download temporary archive file --- .gitignore | 3 ++- pkg/filesystem/file.go | 38 ++++++++++++++++++++-------- routers/controllers/file.go | 18 +++++++++++++- routers/router.go | 7 ++++-- service/explorer/file.go | 49 +++++++++++++++++++++++++++++++++++++ service/explorer/objects.go | 3 ++- 6 files changed, 103 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index a38771f..56465de 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ uploads/* version.lock # Config file -*.ini \ No newline at end of file +*.ini +/conf/conf.ini diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index e12a441..bee9d69 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -26,10 +26,17 @@ func (r lrs) Read(p []byte) (int, error) { } // withSpeedLimit 给原有的ReadSeeker加上限速 -func withSpeedLimit(rs io.ReadSeeker, speed int) io.ReadSeeker { - bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed)) - lrs := lrs{rs, ratelimit.Reader(rs, bucket)} - return lrs +func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker { + // 如果用户组有速度限制,就返回限制流速的ReaderSeeker + if fs.User.Group.SpeedLimit != 0 { + speed := fs.User.Group.SpeedLimit + bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed)) + lrs := lrs{rs, ratelimit.Reader(rs, bucket)} + return lrs + } + // 否则返回原始流 + return rs + } // AddFile 新增文件记录 @@ -54,6 +61,21 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model return &newFile, nil } +// GetPhysicalFileContent 根据文件物理路径获取文件流 +func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) { + // 重设上传策略 + fs.Policy = &model.Policy{Type: "local"} + _ = fs.dispatchHandler() + + // 获取文件流 + rs, err := fs.Handler.Get(ctx, path) + if err != nil { + return nil, err + } + + return fs.withSpeedLimit(rs), nil +} + // GetDownloadContent 获取用于下载的文件流 func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) { // 获取原始文件流 @@ -62,12 +84,8 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R return nil, err } - // 如果用户组有速度限制,就返回限制流速的ReaderSeeker - if fs.User.Group.SpeedLimit != 0 { - return withSpeedLimit(rs, fs.User.Group.SpeedLimit), nil - } - // 否则返回原始流 - return rs, nil + // 返回限速处理后的文件流 + return fs.withSpeedLimit(rs), nil } diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 344cfa3..b1ad703 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -17,7 +17,23 @@ import ( "strconv" ) -func ArchiveAndDownload(c *gin.Context) { +func DownloadArchive(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var service explorer.ArchiveDownloadService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Download(ctx, c) + if res.Code != 0 { + c.JSON(200, res) + } + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +func Archive(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/routers/router.go b/routers/router.go index 95d99f0..46219df 100644 --- a/routers/router.go +++ b/routers/router.go @@ -67,6 +67,7 @@ func InitRouter() *gin.Engine { { file := sign.Group("file") { + // 下載文件 file.GET("get/:id/:name", controllers.AnonymousGetContent) } } @@ -101,8 +102,10 @@ func InitRouter() *gin.Engine { file.GET("thumb/:id", controllers.Thumb) // 取得文件外链 file.GET("source/:id", controllers.GetSource) - // 测试用:压缩文件和目录并下載 - file.POST("archive", controllers.ArchiveAndDownload) + // 打包要下载的文件 + file.POST("archive", controllers.Archive) + // 下載已经打包好的文件 + file.Use(middleware.SignRequired()).GET("archive/:id", controllers.DownloadArchive) } // 目录 diff --git a/service/explorer/file.go b/service/explorer/file.go index 83decce..140538b 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -2,12 +2,16 @@ package explorer import ( "context" + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "io" "net/http" + "os" + "time" ) // FileDownloadService 文件下载服务,path为文件完整路径 @@ -20,6 +24,51 @@ type FileAnonymousGetService struct { Name string `uri:"name" binding:"required"` } +type ArchiveDownloadService struct { + ID string `uri:"id" binding:"required"` +} + +// Download 下載已打包的多文件 +func (service *ArchiveDownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewFileSystemFromContext(c) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + + // 查找打包的临时文件 + zipPath, exist := cache.Get("archive_" + service.ID) + if !exist { + return serializer.Err(404, "归档文件不存在", nil) + } + + // 获取文件流 + rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string)) + if err != nil { + return serializer.Err(serializer.CodeNotSet, err.Error(), err) + } + + c.Header("Content-Type", "application/zip") + http.ServeContent(c.Writer, c.Request, "archive.zip", time.Now(), rs) + + // 检查是否需要关闭文件 + if fc, ok := rs.(io.Closer); ok { + err = fc.Close() + } + + // 清理资源,删除临时文件 + _ = cache.Deletes([]string{service.ID}, "archive_") + err = os.Remove(zipPath.(string)) + if err != nil { + util.Log().Warning("无法删除临时文件 %s :%s", zipPath, err) + } + + return serializer.Response{ + Code: 0, + } + +} + // Download 签名的匿名文件下载 func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Context) serializer.Response { fs, err := filesystem.NewAnonymousFileSystem() diff --git a/service/explorer/objects.go b/service/explorer/objects.go index a7ac37b..b87bef0 100644 --- a/service/explorer/objects.go +++ b/service/explorer/objects.go @@ -12,6 +12,7 @@ import ( "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "net/url" + "time" ) // ItemMoveService 处理多文件/目录移动 @@ -61,7 +62,7 @@ func (service *ItemService) Archive(ctx context.Context, c *gin.Context) seriali zipID := util.RandStringRunes(16) signedURI, err := auth.SignURI( fmt.Sprintf("/api/v3/file/archive/%s", zipID), - 120, + time.Now().Unix()+120, ) finalURL := siteURL.ResolveReference(signedURI).String()