Feat: download file and get file downloading session

pull/247/head
HFO4 5 years ago
parent ad02a659a6
commit f262caf1f5

@ -97,7 +97,8 @@ solid #e9e9e9;"bgcolor="#fff"><tbody><tr style="font-family: 'Helvetica Neue',He
{Name: "ban_time", Value: `10`, Type: "storage_policy"},
{Name: "maxEditSize", Value: `100000`, Type: "file_edit"},
{Name: "oss_timeout", Value: `3600`, Type: "timeout"},
{Name: "local_archive_timeout", Value: `30`, Type: "timeout"},
{Name: "archive_timeout", Value: `30`, Type: "timeout"},
{Name: "download_timeout", Value: `30`, Type: "timeout"},
{Name: "allowdVisitorDownload", Value: `false`, Type: "share"},
{Name: "login_captcha", Value: `0`, Type: "login"},
{Name: "qq_login", Value: `0`, Type: "login"},

@ -8,6 +8,7 @@ import (
"github.com/HFO4/cloudreve/pkg/util"
"github.com/juju/ratelimit"
"io"
"strconv"
)
/* ============
@ -105,8 +106,8 @@ func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeke
return nil, ErrObjectNotExist
}
fs.FileTarget = []model.File{*file}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, file)
}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 将当前存储策略重设为文件使用的
fs.Policy = fs.FileTarget[0].GetPolicy()
@ -173,6 +174,45 @@ func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File)
return policyGroup
}
// GetDownloadURL 创建文件下载链接
func (fs *FileSystem) GetDownloadURL(ctx context.Context, path string) (string, error) {
// 找到文件
if len(fs.FileTarget) == 0 {
exist, file := fs.IsFileExist(path)
if !exist {
return "", ErrObjectNotExist
}
fs.FileTarget = []model.File{*file}
}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 将当前存储策略重设为文件使用的
fs.Policy = fs.FileTarget[0].GetPolicy()
err := fs.dispatchHandler()
if err != nil {
return "", err
}
// 生成下載地址
siteURL := model.GetSiteURL()
ttl, err := strconv.ParseInt(model.GetSettingByName("download_timeout"), 10, 64)
if err != nil {
return "", serializer.NewError(serializer.CodeInternalSetting, "无法获取下载地址有效期", err)
}
source, err := fs.Handler.GetDownloadURL(
ctx,
fs.FileTarget[0].SourceName,
*siteURL,
ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "无法获取下载地址", err)
}
return source, nil
}
// GetSource 获取可直接访问文件的外链地址
func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) {
// 查找文件记录

@ -32,6 +32,8 @@ type Handler interface {
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链地址url
Source(ctx context.Context, path string, url url.URL, expires int64) (string, error)
//获取下载地址
GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error)
}
// FileSystem 管理文件的文件系统

@ -6,6 +6,7 @@ import (
"fmt"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/auth"
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/response"
@ -15,6 +16,7 @@ import (
"net/url"
"os"
"path/filepath"
"time"
)
// Handler 本地策略适配器
@ -127,3 +129,29 @@ func (handler Handler) Source(ctx context.Context, path string, url url.URL, exp
finalURL := url.ResolveReference(signedURI).String()
return finalURL, nil
}
func (handler Handler) GetDownloadURL(ctx context.Context, path string, url url.URL, ttl int64) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return "", errors.New("无法获取文件记录上下文")
}
// 创建下载会话,将文件信息写入缓存
downloadSessionID := util.RandStringRunes(16)
err := cache.Set("download_"+downloadSessionID, file, int(ttl))
if err != nil {
return "", serializer.NewError(serializer.CodeCacheOperation, "无法创建下載会话", err)
}
// 签名生成文件记录
signedURI, err := auth.SignURI(
fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID),
time.Now().Unix()+ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "无法对URL进行签名", err)
}
finalURL := url.ResolveReference(signedURI).String()
return finalURL, nil
}

@ -46,6 +46,10 @@ func (m FileHeaderMock) Source(ctx context.Context, path string, url url.URL, ex
args := m.Called(ctx, path, url, expires)
return args.Get(0).(string), args.Error(1)
}
func (m FileHeaderMock) GetDownloadURL(ctx context.Context, path string, url url.URL, expires int64) (string, error) {
args := m.Called(ctx, path, url, expires)
return args.Get(0).(string), args.Error(1)
}
func TestFileSystem_Upload(t *testing.T) {
asserts := assert.New(t)

@ -66,6 +66,10 @@ const (
CodeEncryptError = 50002
// CodeIOFailed IO操作失败
CodeIOFailed = 50004
// CodeInternalSetting 内部设置参数错误
CodeInternalSetting = 50005
// CodeCacheOperation 缓存操作失败
CodeCacheOperation = 50006
//CodeParamErr 各种奇奇怪怪的参数错误
CodeParamErr = 40001
// CodeNotSet 未定错误后续尝试从error中获取

@ -22,9 +22,9 @@ func DownloadArchive(c *gin.Context) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var service explorer.ArchiveDownloadService
var service explorer.DownloadService
if err := c.ShouldBindUri(&service); err == nil {
res := service.Download(ctx, c)
res := service.DownloadArchived(ctx, c)
if res.Code != 0 {
c.JSON(200, res)
}
@ -137,13 +137,28 @@ func Thumb(c *gin.Context) {
}
// Download 文件下载
// CreateDownloadSession 创建文件下载会话
func CreateDownloadSession(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var service explorer.FileDownloadCreateService
if err := c.ShouldBindUri(&service); err == nil {
res := service.CreateDownloadSession(ctx, c)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// DownloadArchived 文件下载
func Download(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var service explorer.FileDownloadService
var service explorer.DownloadService
if err := c.ShouldBindUri(&service); err == nil {
res := service.Download(ctx, c)
if res.Code != 0 {

@ -75,6 +75,8 @@ func InitRouter() *gin.Engine {
file.GET("get/:id/:name", controllers.AnonymousGetContent)
// 下載已经打包好的文件
file.GET("archive/:id/archive.zip", controllers.DownloadArchive)
// 下载文件
file.GET("download/:id", controllers.Download)
}
}
@ -102,9 +104,9 @@ func InitRouter() *gin.Engine {
{
// 文件上传
file.POST("upload", controllers.FileUploadStream)
// 下载文件
file.GET("download/*path", controllers.Download)
// 下载文件
// 创建文件下载会话
file.PUT("download/*path", controllers.CreateDownloadSession)
// 获取缩略图
file.GET("thumb/:id", controllers.Thumb)
// 取得文件外链
file.GET("source/:id", controllers.GetSource)

@ -2,6 +2,7 @@ package explorer
import (
"context"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
@ -12,22 +13,24 @@ import (
"time"
)
// FileDownloadService 文件下载服务path为文件完整路径
type FileDownloadService struct {
// FileDownloadCreateService 文件下载会话创建服务path为文件完整路径
type FileDownloadCreateService struct {
Path string `uri:"path" binding:"required,min=1,max=65535"`
}
// FileAnonymousGetService 匿名(外链)获取文件服务
type FileAnonymousGetService struct {
ID uint `uri:"id" binding:"required,min=1"`
Name string `uri:"name" binding:"required"`
}
type ArchiveDownloadService struct {
// DownloadService 文件下載服务
type DownloadService struct {
ID string `uri:"id" binding:"required"`
}
// Download 下載已打包的多文件
func (service *ArchiveDownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response {
// DownloadArchived 下載已打包的多文件
func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
@ -98,23 +101,56 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
}
}
// CreateDownloadSession 创建下载会话获取下载URL
func (service *FileDownloadCreateService) CreateDownloadSession(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
// 获取下载地址
downloadURL, err := fs.GetDownloadURL(ctx, service.Path)
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
return serializer.Response{
Code: 0,
Data: downloadURL,
}
}
// Download 文件下载
func (service *FileDownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response {
func (service *DownloadService) Download(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
// 查找打包的临时文件
file, exist := cache.Get("download_" + service.ID)
if !exist {
return serializer.Err(404, "文件下载会话不存在", nil)
}
fs.FileTarget = []model.File{file.(model.File)}
// 开始处理下载
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
rs, err := fs.GetDownloadContent(ctx, service.Path)
rs, err := fs.GetDownloadContent(ctx, "")
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
// 设置文件名
c.Header("Content-Disposition", "attachment; filename=\""+fs.FileTarget[0].Name+"\"")
if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled {
// 清理资源,删除临时文件
_ = cache.Deletes([]string{service.ID}, "download_")
}
// 发送文件
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)

@ -61,7 +61,7 @@ func (service *ItemService) Archive(ctx context.Context, c *gin.Context) seriali
return serializer.Err(serializer.CodeNotSet, "无法解析站点URL", err)
}
zipID := util.RandStringRunes(16)
ttl, err := strconv.Atoi(model.GetSettingByName("local_archive_timeout"))
ttl, err := strconv.Atoi(model.GetSettingByName("archive_timeout"))
if err != nil {
ttl = 30
}

Loading…
Cancel
Save