diff --git a/middleware/option_test.go b/middleware/option_test.go index a0af92d..7048e9b 100644 --- a/middleware/option_test.go +++ b/middleware/option_test.go @@ -1,6 +1,7 @@ package middleware import ( + "github.com/HFO4/cloudreve/bootstrap/constant" "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/hashid" "github.com/gin-gonic/gin" @@ -14,6 +15,7 @@ func TestHashID(t *testing.T) { asserts := assert.New(t) rec := httptest.NewRecorder() TestFunc := HashID(hashid.FolderID) + constant.HashIDTable = []int{0, 1, 2, 3, 4, 5, 6} // 未给定ID对象,跳过 { diff --git a/models/download.go b/models/download.go index f3b3ea2..da47b95 100644 --- a/models/download.go +++ b/models/download.go @@ -48,7 +48,11 @@ func (task *Download) AfterFind() (err error) { // BeforeSave Save下载任务前的钩子 func (task *Download) BeforeSave() (err error) { - return task.AfterFind() + // 解析状态 + if task.Attrs != "" { + err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) + } + return err } // Create 创建离线下载记录 diff --git a/models/migration.go b/models/migration.go index c6de4cd..33d568d 100644 --- a/models/migration.go +++ b/models/migration.go @@ -1,6 +1,7 @@ package model import ( + "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/util" "github.com/jinzhu/gorm" @@ -24,6 +25,11 @@ func migration() { util.Log().Info("开始进行数据库初始化...") + // 清除所有缓存 + if instance, ok := cache.Store.(*cache.RedisStore); ok { + instance.DeleteAll() + } + // 自动迁移模式 if conf.DatabaseConfig.Type == "mysql" { DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") @@ -54,9 +60,7 @@ func addDefaultPolicy() { defaultPolicy := Policy{ Name: "默认存储策略", Type: "local", - Server: "/api/v3/file/upload", - BaseURL: "http://cloudreve.org/public/uploads/", - MaxSize: 10 * 1024 * 1024 * 1024, + MaxSize: 0, AutoRename: true, DirNameRule: "uploads/{uid}/{path}", FileNameRule: "{uid}_{randomkey8}_{originname}", @@ -137,7 +141,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "shopid", Value: ``, Type: "payment"}, {Name: "hot_share_num", Value: `10`, Type: "share"}, {Name: "group_sell_data", Value: `[]`, Type: "group_sell"}, - {Name: "gravatar_server", Value: `https://gravatar.loli.net/`, Type: "avatar"}, + {Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"}, {Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"}, {Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"}, {Name: "aria2_token", Value: ``, Type: "aria2"}, @@ -169,9 +173,9 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "captcha_ComplexOfNoiseText", Value: "0", Type: "captcha"}, {Name: "captcha_ComplexOfNoiseDot", Value: "0", Type: "captcha"}, {Name: "captcha_IsShowHollowLine", Value: "0", Type: "captcha"}, - {Name: "captcha_IsShowNoiseDot", Value: "0", Type: "captcha"}, + {Name: "captcha_IsShowNoiseDot", Value: "1", Type: "captcha"}, {Name: "captcha_IsShowNoiseText", Value: "0", Type: "captcha"}, - {Name: "captcha_IsShowSlimeLine", Value: "0", Type: "captcha"}, + {Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"}, {Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"}, {Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"}, {Name: "thumb_width", Value: "400", Type: "thumb"}, diff --git a/models/policy.go b/models/policy.go index bd35f5a..87f8b01 100644 --- a/models/policy.go +++ b/models/policy.go @@ -102,13 +102,20 @@ func (policy *Policy) SerializeOptions() (err error) { func (policy *Policy) GeneratePath(uid uint, origin string) string { dirRule := policy.DirNameRule replaceTable := map[string]string{ - "{randomkey16}": util.RandStringRunes(16), - "{randomkey8}": util.RandStringRunes(8), - "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), - "{uid}": strconv.Itoa(int(uid)), - "{datetime}": time.Now().Format("20060102150405"), - "{date}": time.Now().Format("20060102"), - "{path}": origin + "/", + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), + "{uid}": strconv.Itoa(int(uid)), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + "{year}": time.Now().Format("2006"), + "{month}": time.Now().Format("01"), + "{day}": time.Now().Format("02"), + "{hour}": time.Now().Format("15"), + "{minute}": time.Now().Format("04"), + "{second}": time.Now().Format("05"), + "{path}": origin + "/", } dirRule = util.Replace(replaceTable, dirRule) return path.Clean(dirRule) @@ -124,12 +131,19 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string { fileRule := policy.FileNameRule replaceTable := map[string]string{ - "{randomkey16}": util.RandStringRunes(16), - "{randomkey8}": util.RandStringRunes(8), - "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), - "{uid}": strconv.Itoa(int(uid)), - "{datetime}": time.Now().Format("20060102150405"), - "{date}": time.Now().Format("20060102"), + "{randomkey16}": util.RandStringRunes(16), + "{randomkey8}": util.RandStringRunes(8), + "{timestamp}": strconv.FormatInt(time.Now().Unix(), 10), + "{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10), + "{uid}": strconv.Itoa(int(uid)), + "{datetime}": time.Now().Format("20060102150405"), + "{date}": time.Now().Format("20060102"), + "{year}": time.Now().Format("2006"), + "{month}": time.Now().Format("01"), + "{day}": time.Now().Format("02"), + "{hour}": time.Now().Format("15"), + "{minute}": time.Now().Format("04"), + "{second}": time.Now().Format("05"), } replaceTable["{originname}"] = policy.getOriginNameRule(origin) diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 413b5a2..6c2ad58 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -93,8 +93,10 @@ func Init(isReload bool) { // 关闭上个初始连接 if previousClient, ok := Instance.(*RPCService); ok { - util.Log().Debug("关闭上个 aria2 连接") - previousClient.caller.Close() + if previousClient.Caller != nil { + util.Log().Debug("关闭上个 aria2 连接") + previousClient.Caller.Close() + } } options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options") diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index 638792a..5ed2e59 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -14,7 +14,7 @@ import ( // RPCService 通过RPC服务的Aria2任务管理器 type RPCService struct { options *clientOptions - caller rpc.Client + Caller rpc.Client } type clientOptions struct { @@ -24,8 +24,8 @@ type clientOptions struct { // Init 初始化 func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error { // 客户端已存在,则关闭先前连接 - if client.caller != nil { - client.caller.Close() + if client.Caller != nil { + client.Caller.Close() } client.options = &clientOptions{ @@ -33,18 +33,18 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s } caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, EventNotifier) - client.caller = caller + client.Caller = caller return err } // Status 查询下载状态 func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { - res, err := client.caller.TellStatus(task.GID) + res, err := client.Caller.TellStatus(task.GID) if err != nil { // 失败后重试 util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) time.Sleep(time.Duration(10) * time.Second) - res, err = client.caller.TellStatus(task.GID) + res, err = client.Caller.TellStatus(task.GID) } return res, err @@ -53,7 +53,7 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { // Cancel 取消下载 func (client *RPCService) Cancel(task *model.Download) error { // 取消下载任务 - _, err := client.caller.Remove(task.GID) + _, err := client.Caller.Remove(task.GID) if err != nil { util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err) } @@ -79,7 +79,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error { for i := 0; i < len(files); i++ { selected[i] = strconv.Itoa(files[i]) } - _, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) + _, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) return err } @@ -103,7 +103,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri options[k] = v } - gid, err := client.caller.AddURI(task.Source, options) + gid, err := client.Caller.AddURI(task.Source, options) if err != nil || gid == "" { return err } diff --git a/pkg/aria2/monitor_test.go b/pkg/aria2/monitor_test.go index e27eae8..c04e89a 100644 --- a/pkg/aria2/monitor_test.go +++ b/pkg/aria2/monitor_test.go @@ -20,7 +20,7 @@ type InstanceMock struct { testMock.Mock } -func (m InstanceMock) CreateTask(task *model.Download, options []interface{}) error { +func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error { args := m.Called(task, options) return args.Error(0) } @@ -307,13 +307,16 @@ func TestMonitor_Complete(t *testing.T) { } cache.Set("setting_max_worker_num", "1", 0) + mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) task.Init() - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() + mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() asserts.True(monitor.Complete(rpc.StatusInfo{})) asserts.NoError(mock.ExpectationsWereMet()) diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 3517e2c..de0c45b 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -201,3 +201,16 @@ func (store *RedisStore) Delete(keys []string, prefix string) error { } return nil } + +// DeleteAll 批量所有键 +func (store *RedisStore) DeleteAll() error { + rc := store.pool.Get() + defer rc.Close() + if rc.Err() != nil { + return rc.Err() + } + + _, err := rc.Do("FLUSHDB") + + return err +} diff --git a/pkg/filesystem/driver/local/handler.go b/pkg/filesystem/driver/local/handler.go index aef80c5..bf83a29 100644 --- a/pkg/filesystem/driver/local/handler.go +++ b/pkg/filesystem/driver/local/handler.go @@ -124,6 +124,15 @@ func (handler Driver) Source( return "", errors.New("无法获取文件记录上下文") } + // 是否启用了CDN + if handler.Policy.BaseURL != "" { + cdnURL, err := url.Parse(handler.Policy.BaseURL) + if err != nil { + return "", err + } + baseURL = *cdnURL + } + var ( signedURI *url.URL err error diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index 749264f..a0ecd5a 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -117,3 +117,47 @@ func AdminDeleteRedeem(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// AdminTestAria2 测试aria2连接 +func AdminTestAria2(c *gin.Context) { + var service admin.Aria2TestService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Test() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminListPolicy 列出存储策略 +func AdminListPolicy(c *gin.Context) { + var service admin.AdminListService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Policies() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminTestPath 测试本地路径可用性 +func AdminTestPath(c *gin.Context) { + var service admin.PathTestService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Test() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminAddPolicy 新建存储策略 +func AdminAddPolicy(c *gin.Context) { + var service admin.AddPolicyService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Add() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 1d7e45e..e149e56 100644 --- a/routers/router.go +++ b/routers/router.go @@ -315,6 +315,23 @@ func InitMasterRouter() *gin.Engine { redeem.DELETE(":id", controllers.AdminDeleteRedeem) } + // 离线下载相关 + aria2 := admin.Group("aria2") + { + // 测试连接配置 + aria2.POST("test", controllers.AdminTestAria2) + } + + policy := admin.Group("policy") + { + // 列出存储策略 + policy.POST("list", controllers.AdminListPolicy) + // 测试本地路径可用性 + policy.POST("test/path", controllers.AdminTestPath) + // 创建存储策略 + policy.POST("", controllers.AdminAddPolicy) + } + } // 用户 diff --git a/service/admin/aria2.go b/service/admin/aria2.go new file mode 100644 index 0000000..358efac --- /dev/null +++ b/service/admin/aria2.go @@ -0,0 +1,42 @@ +package admin + +import ( + "github.com/HFO4/cloudreve/pkg/aria2" + "github.com/HFO4/cloudreve/pkg/serializer" + "net/url" +) + +// Aria2TestService aria2连接测试服务 +type Aria2TestService struct { + Server string `json:"server" binding:"required"` + Token string `json:"token"` +} + +// Test 测试aria2连接 +func (service *Aria2TestService) Test() serializer.Response { + testRPC := aria2.RPCService{} + + // 解析RPC服务地址 + server, err := url.Parse(service.Server) + if err != nil { + return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) + } + server.Path = "/jsonrpc" + + if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { + return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) + } + + defer testRPC.Caller.Close() + + info, err := testRPC.Caller.GetVersion() + if err != nil { + return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) + } + + if info.Version == "" { + return serializer.ParamErr("RPC 服务返回非预期响应", nil) + } + + return serializer.Response{Data: info.Version} +} diff --git a/service/admin/policy.go b/service/admin/policy.go new file mode 100644 index 0000000..c8e9ae2 --- /dev/null +++ b/service/admin/policy.go @@ -0,0 +1,81 @@ +package admin + +import ( + "fmt" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/util" + "os" + "path/filepath" +) + +// PathTestService 本地路径测试服务 +type PathTestService struct { + Path string `json:"path" binding:"required"` +} + +// AddPolicyService 存储策略添加服务 +type AddPolicyService struct { + Policy model.Policy `json:"policy" binding:"required"` +} + +// Add 添加存储策略 +func (service *AddPolicyService) Add() serializer.Response { + if err := model.DB.Create(&service.Policy).Error; err != nil { + return serializer.ParamErr("存储策略添加失败", err) + } + return serializer.Response{} +} + +// Test 测试本地路径 +func (service *PathTestService) Test() serializer.Response { + policy := model.Policy{DirNameRule: service.Path} + path := policy.GeneratePath(1, "/My File") + path = filepath.Join(path, "test.txt") + file, err := util.CreatNestedFile(path) + if err != nil { + return serializer.ParamErr(fmt.Sprintf("无法创建路径 %s , %s", path, err.Error()), nil) + } + + file.Close() + os.Remove(path) + + return serializer.Response{} +} + +// Policies 列出存储策略 +func (service *AdminListService) Policies() serializer.Response { + var res []model.Policy + total := 0 + + tx := model.DB.Model(&model.Policy{}) + if service.OrderBy != "" { + tx = tx.Order(service.OrderBy) + } + + for k, v := range service.Conditions { + tx = tx.Where("? = ?", k, v) + } + + // 计算总数用于分页 + tx.Count(&total) + + // 查询记录 + tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + + // 统计每个策略的文件使用 + statics := make(map[uint][2]int, len(res)) + for i := 0; i < len(res); i++ { + total := [2]int{} + row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID). + Select("count(id),sum(size)").Row() + row.Scan(&total[0], &total[1]) + statics[res[i].ID] = total + } + + return serializer.Response{Data: map[string]interface{}{ + "total": total, + "items": res, + "statics": statics, + }} +}