diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 84b1ed7..990b046 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -54,6 +54,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( defer zipWriter.Close() ctx, _ = context.WithCancel(context.Background()) + // ctx = context.WithValue(ctx, fsctx.UserCtx, *fs.User) // 压缩各个目录及文件 for i := 0; i < len(folders); i++ { fs.doCompress(ctx, nil, &folders[i], zipWriter, true) diff --git a/pkg/filesystem/fsctx/context.go b/pkg/filesystem/fsctx/context.go index 43bda92..66a10bb 100644 --- a/pkg/filesystem/fsctx/context.go +++ b/pkg/filesystem/fsctx/context.go @@ -17,4 +17,6 @@ const ( HTTPCtx // UploadPolicyCtx 上传策略,一般为slave模式下使用 UploadPolicyCtx + // UserCtx 用户 + UserCtx ) diff --git a/pkg/filesystem/remote/handler.go b/pkg/filesystem/remote/handler.go index 98f3ee0..ded6f1a 100644 --- a/pkg/filesystem/remote/handler.go +++ b/pkg/filesystem/remote/handler.go @@ -45,9 +45,32 @@ func (handler Handler) getAPI(scope string) string { } // Get 获取文件内容 +// TODO 测试 func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) { + // 尝试获取速度限制 TODO 是否需要在这里限制? + speedLimit := 0 + if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { + speedLimit = user.Group.SpeedLimit + } - return nil, nil + // 获取文件源地址 + downloadURL, err := handler.Source(ctx, path, url.URL{}, 0, true, speedLimit) + if err != nil { + return nil, err + } + + // 获取文件数据流 + resp, err := handler.Client.Request( + "GET", + downloadURL, + nil, + ).GetRSCloser() + + if err != nil { + return nil, err + } + + return resp, nil } // Put 将文件流保存到指定目录 @@ -125,9 +148,10 @@ func (handler Handler) Source( isDownload bool, speed int, ) (string, error) { - file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) - if !ok { - return "", errors.New("无法获取文件记录上下文") + // 尝试从上下文获取文件名 + fileName := "file" + if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { + fileName = file.Name } serverURL, err := url.Parse(handler.Policy.Server) @@ -144,10 +168,10 @@ func (handler Handler) Source( } // 签名下载地址 - sourcePath := base64.RawURLEncoding.EncodeToString([]byte(file.SourceName)) + sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path)) signedURI, err = auth.SignURI( handler.AuthInstance, - fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, file.Name), + fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, fileName), ttl, ) diff --git a/pkg/filesystem/remote/handler_test.go b/pkg/filesystem/remote/handler_test.go index 0008da2..0ee7e45 100644 --- a/pkg/filesystem/remote/handler_test.go +++ b/pkg/filesystem/remote/handler_test.go @@ -58,12 +58,13 @@ func TestHandler_Source(t *testing.T) { // 无法获取上下文 { handler := Handler{ + Policy: &model.Policy{Server: "/"}, AuthInstance: auth.HMACAuth{}, } ctx := context.Background() res, err := handler.Source(ctx, "", url.URL{}, 0, true, 0) - asserts.Error(err) - asserts.Empty(res) + asserts.NoError(err) + asserts.NotEmpty(res) } // 成功 diff --git a/pkg/request/request.go b/pkg/request/request.go index f0949a8..b523103 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -1,8 +1,10 @@ package request import ( + "errors" "fmt" "github.com/HFO4/cloudreve/pkg/auth" + "github.com/HFO4/cloudreve/pkg/filesystem/response" "io" "io/ioutil" "net/http" @@ -120,3 +122,29 @@ func (resp Response) GetResponse(expectStatus int) (string, error) { } return string(respBody), err } + +type nopRSCloser struct { + body io.ReadCloser +} + +// GetRSCloser 返回带有空seeker的body reader +func (resp Response) GetRSCloser() (response.RSCloser, error) { + return nopRSCloser{ + body: resp.Response.Body, + }, resp.Err +} + +// Read 实现 nopRSCloser reader +func (instance nopRSCloser) Read(p []byte) (n int, err error) { + return instance.body.Read(p) +} + +// 实现 nopRSCloser closer +func (instance nopRSCloser) Close() error { + return instance.body.Close() +} + +// 实现 nopRSCloser seeker +func (instance nopRSCloser) Seek(offset int64, whence int) (int64, error) { + return 0, errors.New("未实现") +} diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 5eb8a10..f76e049 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -229,14 +229,14 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request, fs * } ctx := r.Context() - rs, err := fs.GetContent(ctx, reqPath) + + rs, err := fs.Preview(ctx, reqPath) if err != nil { if err == filesystem.ErrObjectNotExist { return http.StatusNotFound, err } return http.StatusInternalServerError, err } - defer rs.Close() etag, err := findETag(ctx, fs, h.LockSystem[fs.User.ID], reqPath, &fs.FileTarget[0]) if err != nil { @@ -244,8 +244,15 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request, fs * } w.Header().Set("ETag", etag) - // 获取文件内容 - http.ServeContent(w, r, reqPath, fs.FileTarget[0].UpdatedAt, rs) + if !rs.Redirect { + defer rs.Content.Close() + // 获取文件内容 + http.ServeContent(w, r, reqPath, fs.FileTarget[0].UpdatedAt, rs.Content) + return 0, nil + } + + http.Redirect(w, r, rs.URL, 301) + return 0, nil }