Modify: add general ReaderCloserSeeker interface for handler GET method to return

pull/247/head
HFO4 5 years ago
parent f262caf1f5
commit 03dcd9a9e0

@ -4,6 +4,7 @@ import (
"context" "context"
model "github.com/HFO4/cloudreve/models" model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/response"
"github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/juju/ratelimit" "github.com/juju/ratelimit"
@ -18,7 +19,7 @@ import (
// 限速后的ReaderSeeker // 限速后的ReaderSeeker
type lrs struct { type lrs struct {
io.ReadSeeker response.RSCloser
r io.Reader r io.Reader
} }
@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) {
} }
// withSpeedLimit 给原有的ReadSeeker加上限速 // withSpeedLimit 给原有的ReadSeeker加上限速
func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker { func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
// 如果用户组有速度限制就返回限制流速的ReaderSeeker // 如果用户组有速度限制就返回限制流速的ReaderSeeker
if fs.User.Group.SpeedLimit != 0 { if fs.User.Group.SpeedLimit != 0 {
speed := fs.User.Group.SpeedLimit speed := fs.User.Group.SpeedLimit
@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
} }
// GetPhysicalFileContent 根据文件物理路径获取文件流 // GetPhysicalFileContent 根据文件物理路径获取文件流
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) { func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
// 重设上传策略 // 重设上传策略
fs.Policy = &model.Policy{Type: "local"} fs.Policy = &model.Policy{Type: "local"}
_ = fs.dispatchHandler() _ = fs.dispatchHandler()
@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (
} }
// GetDownloadContent 获取用于下载的文件流 // GetDownloadContent 获取用于下载的文件流
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) { func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) {
// 获取原始文件流 // 获取原始文件流
rs, err := fs.GetContent(ctx, path) rs, err := fs.GetContent(ctx, path)
if err != nil { if err != nil {
@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R
} }
// GetContent 获取文件内容path为虚拟路径 // GetContent 获取文件内容path为虚拟路径
func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) { func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) {
// 触发`下载前`钩子 // 触发`下载前`钩子
err := fs.Trigger(ctx, fs.BeforeFileDownload) err := fs.Trigger(ctx, fs.BeforeFileDownload)
if err != nil { if err != nil {

@ -27,7 +27,7 @@ type Handler interface {
// 删除一个或多个文件 // 删除一个或多个文件
Delete(ctx context.Context, files []string) ([]string, error) Delete(ctx context.Context, files []string) ([]string, error)
// 获取文件 // 获取文件
Get(ctx context.Context, path string) (io.ReadSeeker, error) Get(ctx context.Context, path string) (response.RSCloser, error)
// 获取缩略图 // 获取缩略图
Thumb(ctx context.Context, path string) (*response.ContentResponse, error) Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链地址url // 获取外链地址url

@ -25,7 +25,7 @@ type Handler struct {
} }
// Get 获取文件内容 // Get 获取文件内容
func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) { func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 打开文件 // 打开文件
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {

@ -10,3 +10,9 @@ type ContentResponse struct {
Content io.ReadSeeker Content io.ReadSeeker
URL string URL string
} }
// 存储策略适配器返回的文件流有些策略需要带有Closer
type RSCloser interface {
io.ReadSeeker
io.Closer
}

@ -22,9 +22,9 @@ type FileHeaderMock struct {
testMock.Mock testMock.Mock
} }
func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) { func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
args := m.Called(ctx, path) args := m.Called(ctx, path)
return args.Get(0).(io.ReadSeeker), args.Error(1) return args.Get(0).(response.RSCloser), args.Error(1)
} }
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {

@ -8,7 +8,6 @@ import (
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/serializer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io"
"net/http" "net/http"
"time" "time"
) )
@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
// 获取文件流 // 获取文件流
rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string)) rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
defer rs.Close()
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
c.Header("Content-Type", "application/zip") c.Header("Content-Type", "application/zip")
http.ServeContent(c.Writer, c.Request, "", time.Now(), rs) http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
err = fc.Close()
}
return serializer.Response{ return serializer.Response{
Code: 0, Code: 0,
} }
@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
// 获取文件流 // 获取文件流
rs, err := fs.GetDownloadContent(ctx, "") rs, err := fs.GetDownloadContent(ctx, "")
defer rs.Close()
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
// 发送文件 // 发送文件
http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs) http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
defer fc.Close()
}
return serializer.Response{ return serializer.Response{
Code: 0, Code: 0,
} }
@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
// 开始处理下载 // 开始处理下载
ctx = context.WithValue(ctx, fsctx.GinCtx, c) ctx = context.WithValue(ctx, fsctx.GinCtx, c)
rs, err := fs.GetDownloadContent(ctx, "") rs, err := fs.GetDownloadContent(ctx, "")
defer rs.Close()
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
// 发送文件 // 发送文件
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs) http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
defer fc.Close()
}
return serializer.Response{ return serializer.Response{
Code: 0, Code: 0,
} }

Loading…
Cancel
Save