Feat: cancel archive action / request with context

pull/247/head
HFO4 5 years ago
parent 1393659668
commit aeca161186

@ -3,10 +3,14 @@ package filesystem
import ( import (
"archive/zip" "archive/zip"
"context" "context"
"errors"
"fmt" "fmt"
model "github.com/HFO4/cloudreve/models" model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-gonic/gin"
"io" "io"
"os"
"path/filepath" "path/filepath"
"time" "time"
) )
@ -29,6 +33,11 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (
return "", ErrDBListObjects return "", ErrDBListObjects
} }
ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context)
if !ok {
return "", errors.New("无法获取请求上下文")
}
// 将顶级待处理对象的路径设为根路径 // 将顶级待处理对象的路径设为根路径
for i := 0; i < len(folders); i++ { for i := 0; i < len(folders); i++ {
folders[i].Position = "" folders[i].Position = ""
@ -53,19 +62,43 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (
zipWriter := zip.NewWriter(zipFile) zipWriter := zip.NewWriter(zipFile)
defer zipWriter.Close() defer zipWriter.Close()
ctx, _ = context.WithCancel(context.Background()) ctx = context.WithValue(ginCtx.Request.Context(), fsctx.UserCtx, *fs.User)
// ctx = context.WithValue(ctx, fsctx.UserCtx, *fs.User)
// 压缩各个目录及文件 // 压缩各个目录及文件
for i := 0; i < len(folders); i++ { for i := 0; i < len(folders); i++ {
select {
case <-ginCtx.Request.Context().Done():
// 取消压缩请求
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled
default:
fs.doCompress(ctx, nil, &folders[i], zipWriter, true) fs.doCompress(ctx, nil, &folders[i], zipWriter, true)
} }
}
for i := 0; i < len(files); i++ { for i := 0; i < len(files); i++ {
select {
case <-ginCtx.Request.Context().Done():
// 取消压缩请求
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled
default:
fs.doCompress(ctx, &files[i], nil, zipWriter, true) fs.doCompress(ctx, &files[i], nil, zipWriter, true)
} }
}
return zipFilePath, nil return zipFilePath, nil
} }
// cancelCompress 取消压缩进程
// TODO 测试
func (fs *FileSystem) cancelCompress(ctx context.Context, zipWriter *zip.Writer, file *os.File, path string) {
util.Log().Debug("客户端取消压缩请求")
zipWriter.Close()
file.Close()
_ = os.Remove(path)
}
func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) { func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
// 如果对象是文件 // 如果对象是文件
if file != nil { if file != nil {

@ -11,6 +11,7 @@ var (
ErrFileExtensionNotAllowed = errors.New("不允许上传此类型的文件") ErrFileExtensionNotAllowed = errors.New("不允许上传此类型的文件")
ErrInsufficientCapacity = errors.New("容量空间不足") ErrInsufficientCapacity = errors.New("容量空间不足")
ErrIllegalObjectName = errors.New("目标名称非法") ErrIllegalObjectName = errors.New("目标名称非法")
ErrClientCanceled = errors.New("客户端取消操作")
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "无法插入文件记录", nil) ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "无法插入文件记录", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "同名文件已存在", nil) ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "同名文件已存在", nil)
ErrFolderExisted = serializer.NewError(serializer.CodeObjectExist, "同名目录已存在", nil) ErrFolderExisted = serializer.NewError(serializer.CodeObjectExist, "同名目录已存在", nil)

@ -64,6 +64,7 @@ func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser,
"GET", "GET",
downloadURL, downloadURL,
nil, nil,
request.WithContext(ctx),
).GetRSCloser() ).GetRSCloser()
if err != nil { if err != nil {

@ -1,6 +1,7 @@
package request package request
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/auth"
@ -38,6 +39,7 @@ type options struct {
header http.Header header http.Header
sign auth.Auth sign auth.Auth
signTTL int64 signTTL int64
ctx context.Context
} }
type optionFunc func(*options) type optionFunc func(*options)
@ -60,6 +62,14 @@ func WithTimeout(t time.Duration) Option {
}) })
} }
// WithContext 设置请求上下文
// TODO 测试
func WithContext(c context.Context) Option {
return optionFunc(func(o *options) {
o.ctx = c
})
}
// WithCredential 对请求进行签名 // WithCredential 对请求进行签名
func WithCredential(instance auth.Auth, ttl int64) Option { func WithCredential(instance auth.Auth, ttl int64) Option {
return optionFunc(func(o *options) { return optionFunc(func(o *options) {
@ -87,7 +97,15 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
client := &http.Client{Timeout: options.timeout} client := &http.Client{Timeout: options.timeout}
// 创建请求 // 创建请求
req, err := http.NewRequest(method, target, body) var (
req *http.Request
err error
)
if options.ctx != nil {
req, err = http.NewRequestWithContext(options.ctx, method, target, body)
} else {
req, err = http.NewRequest(method, target, body)
}
if err != nil { if err != nil {
return Response{Err: err} return Response{Err: err}
} }

Loading…
Cancel
Save