enhance(download): Use just-in-time host in download URl, instead of SiteURL in site settings

pull/1741/head
Aaron Liu 2 years ago
parent 4c834e75fa
commit 4aafe1dc7a

@ -150,14 +150,7 @@ func (handler Driver) CORS() error {
// Get 获取文件 // Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -267,14 +260,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
// Driver Google Drive 适配器 // Driver Google Drive 适配器
@ -45,7 +44,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
panic("implement me") panic("implement me")
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
} }

@ -7,7 +7,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
var ( var (
@ -37,7 +36,7 @@ type Handler interface {
// 获取外链/下载地址, // 获取外链/下载地址,
// url - 站点本身地址, // url - 站点本身地址,
// isDownload - 是否直接下载 // isDownload - 是否直接下载
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名 // Token 获取有效期为ttl的上传凭证和签名
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)

@ -219,26 +219,20 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File) file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok { if !ok {
return "", errors.New("failed to read file model context") return "", errors.New("failed to read file model context")
} }
var baseURL *url.URL
// 是否启用了CDN // 是否启用了CDN
if handler.Policy.BaseURL != "" { if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL) cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil { if err != nil {
return "", err return "", err
} }
baseURL = *cdnURL baseURL = cdnURL
} }
var ( var (
@ -272,7 +266,11 @@ func (handler Driver) Source(
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err) return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err)
} }
finalURL := baseURL.ResolveReference(signedURI).String() finalURL := signedURI.String()
if baseURL != nil {
finalURL = baseURL.ResolveReference(signedURI).String()
}
return finalURL, nil return finalURL, nil
} }

@ -91,7 +91,6 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
downloadURL, err := handler.Source( downloadURL, err := handler.Source(
ctx, ctx,
path, path,
url.URL{},
60, 60,
false, false,
0, 0,
@ -164,7 +163,6 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
func (handler Driver) Source( func (handler Driver) Source(
ctx context.Context, ctx context.Context,
path string, path string,
baseURL url.URL,
ttl int64, ttl int64,
isDownload bool, isDownload bool,
speed int, speed int,

@ -9,7 +9,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -106,7 +105,7 @@ func TestDriver_Source(t *testing.T) {
// 失败 // 失败
{ {
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.Error(err) asserts.Error(err)
asserts.Empty(res) asserts.Empty(res)
} }
@ -116,7 +115,7 @@ func TestDriver_Source(t *testing.T) {
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
cache.Set("onedrive_source_0_123.jpg", "res", 1) cache.Set("onedrive_source_0_123.jpg", "res", 1)
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("res", res) asserts.Equal("res", res)
@ -131,7 +130,7 @@ func TestDriver_Source(t *testing.T) {
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0) cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(ctx, "123.jpg", 1, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("res", res) asserts.Equal("res", res)
@ -156,7 +155,7 @@ func TestDriver_Source(t *testing.T) {
}) })
handler.Client.Request = clientMock handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1" handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0) res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("123321", res) asserts.Equal("123321", res)
} }

@ -194,14 +194,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false) ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false)
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,14 +332,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 初始化客户端 // 初始化客户端
usePublicEndpoint := true usePublicEndpoint := true
if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok { if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok {

@ -119,14 +119,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano()) path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano())
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -264,14 +257,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

@ -124,7 +124,7 @@ func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser,
} }
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source(ctx, path, url.URL{}, 0, true, speedLimit) downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,14 +233,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "file" fileName := "file"
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

@ -9,7 +9,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"testing" "testing"
@ -51,7 +50,7 @@ func TestHandler_Source(t *testing.T) {
AuthInstance: auth.HMACAuth{}, AuthInstance: auth.HMACAuth{},
} }
ctx := context.Background() ctx := context.Background()
res, err := handler.Source(ctx, "", url.URL{}, 0, true, 0) res, err := handler.Source(ctx, "", 0, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.NotEmpty(res) asserts.NotEmpty(res)
} }
@ -66,7 +65,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/download/0") asserts.Contains(res, "api/v3/slave/download/0")
} }
@ -81,7 +80,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/download/0") asserts.Contains(res, "api/v3/slave/download/0")
asserts.Contains(res, "https://cqu.edu.cn") asserts.Contains(res, "https://cqu.edu.cn")
@ -97,7 +96,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) res, err := handler.Source(ctx, "", 10, true, 0)
asserts.Error(err) asserts.Error(err)
asserts.Empty(res) asserts.Empty(res)
} }
@ -112,7 +111,7 @@ func TestHandler_Source(t *testing.T) {
SourceName: "1.txt", SourceName: "1.txt",
} }
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
res, err := handler.Source(ctx, "", url.URL{}, 10, false, 0) res, err := handler.Source(ctx, "", 10, false, 0)
asserts.NoError(err) asserts.NoError(err)
asserts.Contains(res, "api/v3/slave/source/0") asserts.Contains(res, "api/v3/slave/source/0")
} }

@ -164,14 +164,7 @@ func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([
// Get 获取文件 // Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -270,14 +263,7 @@ func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.C
} }
// Source 获取外链URL // Source 获取外链URL
func (handler *Driver) Source( func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""

@ -8,7 +8,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
) )
// Driver 影子存储策略,用于在从机端上传文件 // Driver 影子存储策略,用于在从机端上传文件
@ -43,7 +42,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented return "", ErrNotImplemented
} }

@ -106,7 +106,7 @@ func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.Content
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) { func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented return "", ErrNotImplemented
} }

@ -107,14 +107,7 @@ func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]
// Get 获取文件 // Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址 // 获取文件源地址
downloadURL, err := handler.Source( downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
path,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -243,14 +236,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
thumbParam := fmt.Sprintf("!/fwfh/%dx%d", thumbSize[0], thumbSize[1]) thumbParam := fmt.Sprintf("!/fwfh/%dx%d", thumbSize[0], thumbSize[1])
thumbURL, err := handler.Source( thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
file.SourceName+thumbParam,
url.URL{},
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -262,14 +248,7 @@ func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.Co
} }
// Source 获取外链URL // Source 获取外链URL
func (handler Driver) Source( func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
ctx context.Context,
path string,
baseURL url.URL,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
// 尝试从上下文获取文件名 // 尝试从上下文获取文件名
fileName := "" fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {

@ -300,8 +300,7 @@ func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64,
// 签名最终URL // 签名最终URL
// 生成外链地址 // 生成外链地址
siteURL := model.GetSiteURL() source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit)
source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, *siteURL, ttl, isDownload, fs.User.Group.SpeedLimit)
if err != nil { if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err) return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
} }

@ -57,14 +57,7 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR
res = &response.ContentResponse{ res = &response.ContentResponse{
Redirect: true, Redirect: true,
} }
res.URL, err = fs.Handler.Source( res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
ctx,
file.ThumbFile(),
*model.GetSiteURL(),
int64(model.GetIntSetting("preview_timeout", 60)),
false,
0,
)
} else { } else {
// if not exist, generate and upload the sidecar thumb. // if not exist, generate and upload the sidecar thumb.
if err = fs.generateThumbnail(ctx, &file); err == nil { if err = fs.generateThumbnail(ctx, &file); err == nil {

Loading…
Cancel
Save