Feat: download file and get file downloading session

pull/247/head
HFO4 5 years ago
parent ad02a659a6
commit f262caf1f5

@ -97,7 +97,8 @@ solid #e9e9e9;"bgcolor="#fff"><tbody><tr style="font-family: 'Helvetica Neue',He
{Name: "ban_time", Value: `10`, Type: "storage_policy"}, {Name: "ban_time", Value: `10`, Type: "storage_policy"},
{Name: "maxEditSize", Value: `100000`, Type: "file_edit"}, {Name: "maxEditSize", Value: `100000`, Type: "file_edit"},
{Name: "oss_timeout", Value: `3600`, Type: "timeout"}, {Name: "oss_timeout", Value: `3600`, Type: "timeout"},
{Name: "local_archive_timeout", Value: `30`, Type: "timeout"}, {Name: "archive_timeout", Value: `30`, Type: "timeout"},
{Name: "download_timeout", Value: `30`, Type: "timeout"},
{Name: "allowdVisitorDownload", Value: `false`, Type: "share"}, {Name: "allowdVisitorDownload", Value: `false`, Type: "share"},
{Name: "login_captcha", Value: `0`, Type: "login"}, {Name: "login_captcha", Value: `0`, Type: "login"},
{Name: "qq_login", Value: `0`, Type: "login"}, {Name: "qq_login", Value: `0`, Type: "login"},

@ -8,6 +8,7 @@ import (
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/juju/ratelimit" "github.com/juju/ratelimit"
"io" "io"
"strconv"
) )
/* ============ /* ============
@ -105,8 +106,8 @@ func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeke
return nil, ErrObjectNotExist return nil, ErrObjectNotExist
} }
fs.FileTarget = []model.File{*file} fs.FileTarget = []model.File{*file}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, file)
} }
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 将当前存储策略重设为文件使用的 // 将当前存储策略重设为文件使用的
fs.Policy = fs.FileTarget[0].GetPolicy() fs.Policy = fs.FileTarget[0].GetPolicy()
@ -173,6 +174,45 @@ func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File)
return policyGroup return policyGroup
} }
// GetDownloadURL 创建文件下载链接
func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string) (string, error) {
// 找到文件
if len(fs.FileTarget) == 0 {
exist, file := fs.IsFileExist(path)
if !exist {
return "", ErrObjectNotExist
}
fs.FileTarget = []model.File{*file}
}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 将当前存储策略重设为文件使用的
fs.Policy = fs.FileTarget[0].GetPolicy()
err := fs.dispatchHandler()
if err != nil {
return "", err
}
// 生成下載地址
siteURL := model.GetSiteURL()
ttl, err := strconv.ParseInt(model.GetSettingByName("download_timeout"), 10, 64)
if err != nil {
return "", serializer.NewError(serializer.CodeInternalSetting, "无法获取下载地址有效期", err)
}
source, err := fs.Handler.GetDownloadURL(
ctx,
fs.FileTarget[0].SourceName,
*siteURL,
ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "无法获取下载地址", err)
}
return source, nil
}
// GetSource 获取可直接访问文件的外链地址 // GetSource 获取可直接访问文件的外链地址
func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) { func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) {
// 查找文件记录 // 查找文件记录

@ -32,6 +32,8 @@ type Handler interface {
Thumb(ctx context.Context, path string) (*response.ContentResponse, error) Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链地址url // 获取外链地址url
Source(ctx context.Context, path string, url url.URL, expires int64) (string, error) Source(ctx context.Context, path string, url url.URL, expires int64) (string, error)
//获取下载地址
GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error)
} }
// FileSystem 管理文件的文件系统 // FileSystem 管理文件的文件系统

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
model "github.com/HFO4/cloudreve/models" model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/auth"
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/response" "github.com/HFO4/cloudreve/pkg/filesystem/response"
@ -15,6 +16,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"time"
) )
// Handler 本地策略适配器 // Handler 本地策略适配器
@ -127,3 +129,29 @@ func (handler Handler) Source(ctx context.Context, path string, url url.URL, exp
finalURL := url.ResolveReference(signedURI).String() finalURL := url.ResolveReference(signedURI).String()
return finalURL, nil return finalURL, nil
} }
func (handler Handler) GetDownloadURL(ctx context.Context, path string, url url.URL, ttl int64) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return "", errors.New("无法获取文件记录上下文")
}
// 创建下载会话,将文件信息写入缓存
downloadSessionID := util.RandStringRunes(16)
err := cache.Set("download_"+downloadSessionID, file, int(ttl))
if err != nil {
return "", serializer.NewError(serializer.CodeCacheOperation, "无法创建下載会话", err)
}
// 签名生成文件记录
signedURI, err := auth.SignURI(
fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID),
time.Now().Unix()+ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "无法对URL进行签名", err)
}
finalURL := url.ResolveReference(signedURI).String()
return finalURL, nil
}

@ -46,6 +46,10 @@ func (m FileHeaderMock) Source(ctx context.Context, path string, url url.URL, ex
args := m.Called(ctx, path, url, expires) args := m.Called(ctx, path, url, expires)
return args.Get(0).(string), args.Error(1) return args.Get(0).(string), args.Error(1)
} }
func (m FileHeaderMock) GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error) {
args := m.Called(ctx, path, url, expires)
return args.Get(0).(string), args.Error(1)
}
func TestFileSystem_Upload(t *testing.T) { func TestFileSystem_Upload(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)

@ -66,6 +66,10 @@ const (
CodeEncryptError = 50002 CodeEncryptError = 50002
// CodeIOFailed IO操作失败 // CodeIOFailed IO操作失败
CodeIOFailed = 50004 CodeIOFailed = 50004
// CodeInternalSetting 内部设置参数错误
CodeInternalSetting = 50005
// CodeCacheOperation 缓存操作失败
CodeCacheOperation = 50006
//CodeParamErr 各种奇奇怪怪的参数错误 //CodeParamErr 各种奇奇怪怪的参数错误
CodeParamErr = 40001 CodeParamErr = 40001
// CodeNotSet 未定错误后续尝试从error中获取 // CodeNotSet 未定错误后续尝试从error中获取

@ -22,9 +22,9 @@ func DownloadArchive(c *gin.Context) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
var service explorer.ArchiveDownloadService var service explorer.DownloadService
if err := c.ShouldBindUri(&service); err == nil { if err := c.ShouldBindUri(&service); err == nil {
res := service.Download(ctx, c) res := service.DownloadArchived(ctx, c)
if res.Code != 0 { if res.Code != 0 {
c.JSON(200, res) c.JSON(200, res)
} }
@ -137,13 +137,28 @@ func Thumb(c *gin.Context) {
} }
// Download 文件下载 // CreateDownloadSession 创建文件下载会话
func CreateDownloadSession(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var service explorer.FileDownloadCreateService
if err := c.ShouldBindUri(&service); err == nil {
res := service.CreateDownloadSession(ctx, c)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// DownloadArchived 文件下载
func Download(c *gin.Context) { func Download(c *gin.Context) {
// 创建上下文 // 创建上下文
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
var service explorer.FileDownloadService var service explorer.DownloadService
if err := c.ShouldBindUri(&service); err == nil { if err := c.ShouldBindUri(&service); err == nil {
res := service.Download(ctx, c) res := service.Download(ctx, c)
if res.Code != 0 { if res.Code != 0 {

@ -75,6 +75,8 @@ func InitRouter() *gin.Engine {
file.GET("get/:id/:name", controllers.AnonymousGetContent) file.GET("get/:id/:name", controllers.AnonymousGetContent)
// 下載已经打包好的文件 // 下載已经打包好的文件
file.GET("archive/:id/archive.zip", controllers.DownloadArchive) file.GET("archive/:id/archive.zip", controllers.DownloadArchive)
// 下载文件
file.GET("download/:id", controllers.Download)
} }
} }
@ -102,9 +104,9 @@ func InitRouter() *gin.Engine {
{ {
// 文件上传 // 文件上传
file.POST("upload", controllers.FileUploadStream) file.POST("upload", controllers.FileUploadStream)
// 下载文件 // 创建文件下载会话
file.GET("download/*path", controllers.Download) file.PUT("download/*path", controllers.CreateDownloadSession)
// 下载文件 // 获取缩略图
file.GET("thumb/:id", controllers.Thumb) file.GET("thumb/:id", controllers.Thumb)
// 取得文件外链 // 取得文件外链
file.GET("source/:id", controllers.GetSource) file.GET("source/:id", controllers.GetSource)

@ -2,6 +2,7 @@ package explorer
import ( import (
"context" "context"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
@ -12,22 +13,24 @@ import (
"time" "time"
) )
// FileDownloadService 文件下载服务path为文件完整路径 // FileDownloadCreateService 文件下载会话创建服务path为文件完整路径
type FileDownloadService struct { type FileDownloadCreateService struct {
Path string `uri:"path" binding:"required,min=1,max=65535"` Path string `uri:"path" binding:"required,min=1,max=65535"`
} }
// FileAnonymousGetService 匿名(外链)获取文件服务
type FileAnonymousGetService struct { type FileAnonymousGetService struct {
ID uint `uri:"id" binding:"required,min=1"` ID uint `uri:"id" binding:"required,min=1"`
Name string `uri:"name" binding:"required"` Name string `uri:"name" binding:"required"`
} }
type ArchiveDownloadService struct { // DownloadService 文件下載服务
type DownloadService struct {
ID string `uri:"id" binding:"required"` ID string `uri:"id" binding:"required"`
} }
// Download 下載已打包的多文件 // DownloadArchived 下載已打包的多文件
func (service *ArchiveDownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response { func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c) fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil { if err != nil {
@ -98,23 +101,56 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
} }
} }
// CreateDownloadSession 创建下载会话获取下载URL
func (service *FileDownloadCreateService) CreateDownloadSession(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
// 获取下载地址
downloadURL, err := fs.GetDownloadURL(ctx, service.Path)
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
return serializer.Response{
Code: 0,
Data: downloadURL,
}
}
// Download 文件下载 // Download 文件下载
func (service *FileDownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response { func (service *DownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c) fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil { if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
} }
// 查找打包的临时文件
file, exist := cache.Get("download_" + service.ID)
if !exist {
return serializer.Err(404, "文件下载会话不存在", nil)
}
fs.FileTarget = []model.File{file.(model.File)}
// 开始处理下载 // 开始处理下载
ctx = context.WithValue(ctx, fsctx.GinCtx, c) ctx = context.WithValue(ctx, fsctx.GinCtx, c)
rs, err := fs.GetDownloadContent(ctx, service.Path) rs, err := fs.GetDownloadContent(ctx, "")
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
// 设置文件名 // 设置文件名
c.Header("Content-Disposition", "attachment; filename=\""+fs.FileTarget[0].Name+"\"") c.Header("Content-Disposition", "attachment; filename=\""+fs.FileTarget[0].Name+"\"")
if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled {
// 清理资源,删除临时文件
_ = cache.Deletes([]string{service.ID}, "download_")
}
// 发送文件 // 发送文件
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs) http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)

@ -61,7 +61,7 @@ func (service *ItemService) Archive(ctx context.Context, c *gin.Context) seriali
return serializer.Err(serializer.CodeNotSet, "无法解析站点URL", err) return serializer.Err(serializer.CodeNotSet, "无法解析站点URL", err)
} }
zipID := util.RandStringRunes(16) zipID := util.RandStringRunes(16)
ttl, err := strconv.Atoi(model.GetSettingByName("local_archive_timeout")) ttl, err := strconv.Atoi(model.GetSettingByName("archive_timeout"))
if err != nil { if err != nil {
ttl = 30 ttl = 30
} }

Loading…
Cancel
Save