diff --git a/models/download_test.go b/models/download_test.go index d24531b..7be81d0 100644 --- a/models/download_test.go +++ b/models/download_test.go @@ -63,14 +63,6 @@ func TestDownload_AfterFind(t *testing.T) { asserts.Equal("", download.StatusInfo.Gid) } - // 关联任务 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "error"}).AddRow(1, "error")) - download := Download{TaskID: 1} - download.BeforeSave() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", download.Task.Error) - } } func TestDownload_Save(t *testing.T) { diff --git a/pkg/filesystem/driver/local/handller_test.go b/pkg/filesystem/driver/local/handller_test.go index 4033f31..8924b5c 100644 --- a/pkg/filesystem/driver/local/handller_test.go +++ b/pkg/filesystem/driver/local/handller_test.go @@ -120,7 +120,9 @@ func TestHandler_Thumb(t *testing.T) { func TestHandler_Source(t *testing.T) { asserts := assert.New(t) - handler := Driver{} + handler := Driver{ + Policy: &model.Policy{}, + } ctx := context.Background() auth.General = auth.HMACAuth{SecretKey: []byte("test")} @@ -150,6 +152,42 @@ func TestHandler_Source(t *testing.T) { asserts.Error(err) asserts.Empty(sourceURL) } + + // 设定了CDN + { + handler.Policy.BaseURL = "https://cqu.edu.cn" + file := model.File{ + Model: gorm.Model{ + ID: 1, + }, + Name: "test.jpg", + } + ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0, false, 0) + asserts.NoError(err) + asserts.NotEmpty(sourceURL) + asserts.Contains(sourceURL, "sign=") + asserts.Contains(sourceURL, "https://cqu.edu.cn") + } + + // 设定了CDN,解析失败 + { + handler.Policy.BaseURL = string(0x7f) + file := model.File{ + Model: gorm.Model{ + ID: 1, + }, + Name: "test.jpg", + } + ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) + baseURL, err := url.Parse("https://cloudreve.org") + asserts.NoError(err) + sourceURL, err := handler.Source(ctx, "", *baseURL, 0, false, 0) + asserts.Error(err) + asserts.Empty(sourceURL) + } } func TestHandler_GetDownloadURL(t *testing.T) { diff --git a/pkg/filesystem/driver/oss/handler_test.go b/pkg/filesystem/driver/oss/handler_test.go index f379e28..3a4b931 100644 --- a/pkg/filesystem/driver/oss/handler_test.go +++ b/pkg/filesystem/driver/oss/handler_test.go @@ -39,6 +39,23 @@ func TestDriver_InitOSSClient(t *testing.T) { } } +func TestDriver_CORS(t *testing.T) { + asserts := assert.New(t) + handler := Driver{ + Policy: &model.Policy{ + AccessKey: "ak", + SecretKey: "sk", + BucketName: "test", + Server: "test.com", + }, + } + + // 失败 + { + asserts.Error(handler.CORS()) + } +} + func TestDriver_Token(t *testing.T) { asserts := assert.New(t) handler := Driver{ @@ -149,7 +166,18 @@ func TestDriver_Source(t *testing.T) { asserts.NoError(err) query := resURL.Query() asserts.Empty(query.Get("Signature")) - asserts.Empty(query.Get("Expires")) + } + + // 正常 指定了CDN域名 + { + handler.Policy.BaseURL = "https://cqu.edu.cn" + res, err := handler.Source(context.Background(), "/123", url.URL{}, 10, false, 0) + asserts.NoError(err) + resURL, err := url.Parse(res) + asserts.NoError(err) + query := resURL.Query() + asserts.Empty(query.Get("Signature")) + asserts.Contains(resURL.String(), handler.Policy.BaseURL) } } diff --git a/pkg/filesystem/driver/remote/handler_test.go b/pkg/filesystem/driver/remote/handler_test.go index 70b7a32..6315674 100644 --- a/pkg/filesystem/driver/remote/handler_test.go +++ b/pkg/filesystem/driver/remote/handler_test.go @@ -84,6 +84,37 @@ func TestHandler_Source(t *testing.T) { asserts.Contains(res, "api/v3/slave/download/0") } + // 成功 自定义CDN + { + handler := Driver{ + Policy: &model.Policy{Server: "/", BaseURL: "https://cqu.edu.cn"}, + AuthInstance: auth.HMACAuth{}, + } + file := model.File{ + SourceName: "1.txt", + } + ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) + res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) + asserts.NoError(err) + asserts.Contains(res, "api/v3/slave/download/0") + asserts.Contains(res, "https://cqu.edu.cn") + } + + // 解析失败 自定义CDN + { + handler := Driver{ + Policy: &model.Policy{Server: "/", BaseURL: string(0x7f)}, + AuthInstance: auth.HMACAuth{}, + } + file := model.File{ + SourceName: "1.txt", + } + ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) + res, err := handler.Source(ctx, "", url.URL{}, 10, true, 0) + asserts.Error(err) + asserts.Empty(res) + } + // 成功 预览 { handler := Driver{ diff --git a/pkg/hashid/hash_test.go b/pkg/hashid/hash_test.go index 5471d9e..2f7472d 100644 --- a/pkg/hashid/hash_test.go +++ b/pkg/hashid/hash_test.go @@ -1,6 +1,7 @@ package hashid import ( + "github.com/HFO4/cloudreve/bootstrap/constant" "github.com/stretchr/testify/assert" "testing" ) @@ -52,6 +53,7 @@ func TestHashDecode(t *testing.T) { func TestDecodeHashID(t *testing.T) { asserts := assert.New(t) + constant.HashIDTable = []int{0, 1, 2, 3, 4, 5} // 成功 {