diff --git a/go.mod b/go.mod index f54550b..f91fe45 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/gin-gonic/gin v1.4.0 github.com/go-ini/ini v1.50.0 github.com/jinzhu/gorm v1.9.11 + github.com/juju/ratelimit v1.0.1 github.com/mattn/go-colorable v0.1.4 // indirect github.com/mcuadros/go-version v0.0.0-20190830083331-035f6764e8d2 github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2 diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index c1be240..590af48 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -4,6 +4,7 @@ import ( "context" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/util" + "github.com/juju/ratelimit" "io" ) @@ -12,6 +13,23 @@ import ( ============ */ +// 限速后的ReaderSeeker +type lrs struct { + io.ReadSeeker + r io.Reader +} + +func (r lrs) Read(p []byte) (int, error) { + return r.r.Read(p) +} + +// withSpeedLimit 给原有的ReadSeeker加上限速 +func withSpeedLimit(rs io.ReadSeeker, speed int) io.ReadSeeker { + bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed)) + lrs := lrs{rs, ratelimit.Reader(rs, bucket)} + return lrs +} + // AddFile 新增文件记录 func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model.File, error) { file := ctx.Value(FileHeaderCtx).(FileHeader) @@ -35,6 +53,23 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model return &newFile, nil } +// GetDownloadContent 获取用于下载的文件流 +func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) { + // 获取原始文件流 + rs, err := fs.GetContent(ctx, path) + if err != nil { + return nil, err + } + + // 如果用户组有速度限制,就返回限制流速的ReaderSeeker + if fs.User.Group.SpeedLimit != 0 { + return withSpeedLimit(rs, fs.User.Group.SpeedLimit), nil + } + // 否则返回原始流 + return rs, nil + +} + // GetContent 获取文件内容,path为虚拟路径 // TODO:测试 func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) { diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index eb0ba18..7e3c5c6 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -77,6 +77,8 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe case <-ctx.Done(): // 客户端正常关闭,不执行操作 default: + // 客户端取消上传,删除临时文件 + util.Log().Debug("客户端取消上传") if fs.AfterUploadCanceled == nil { return } diff --git a/service/explorer/file.go b/service/explorer/file.go index bad8af1..2af51eb 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -24,7 +24,7 @@ func (service *FileDownloadService) Download(ctx context.Context, c *gin.Context // 开始处理下载 ctx = context.WithValue(ctx, filesystem.GinCtx, c) - rs, err := fs.GetContent(ctx, service.Path) + rs, err := fs.GetDownloadContent(ctx, service.Path) if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) }