package admin import ( "bytes" "encoding/json" "fmt" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/filesystem/driver/oss" "github.com/HFO4/cloudreve/pkg/request" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "net/url" "os" "path/filepath" "time" ) // PathTestService 本地路径测试服务 type PathTestService struct { Path string `json:"path" binding:"required"` } // SlaveTestService 从机测试服务 type SlaveTestService struct { Secret string `json:"secret" binding:"required"` Server string `json:"server" binding:"required"` } // SlavePingService 从机相应ping type SlavePingService struct { Callback string `json:"callback" binding:"required"` } // AddPolicyService 存储策略添加服务 type AddPolicyService struct { Policy model.Policy `json:"policy" binding:"required"` } // PolicyService 存储策略ID服务 type PolicyService struct { ID uint `json:"id" binding:"required"` } // AddCORS 创建跨域策略 func (service *PolicyService) AddCORS() serializer.Response { policy, err := model.GetPolicyByID(service.ID) if err != nil { return serializer.Err(serializer.CodeNotFound, "存储策略不存在", nil) } switch policy.Type { case "oss": handler := oss.Driver{ Policy: &policy, HTTPClient: request.HTTPClient{}, } if err := handler.CORS(); err != nil { return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) } default: return serializer.ParamErr("不支持此策略", nil) } return serializer.Response{} } // Test 从机响应ping func (service *SlavePingService) Test() serializer.Response { master, err := url.Parse(service.Callback) if err != nil { return serializer.ParamErr("无法解析主机站点地址,请检查主机 参数设置 - 站点信息 - 站点URL设置,"+err.Error(), nil) } controller, _ := url.Parse("/api/v3/site/ping") r := request.HTTPClient{} res, err := r.Request( "GET", master.ResolveReference(controller).String(), nil, request.WithTimeout(time.Duration(10)*time.Second), ).DecodeResponse() if err != nil { return serializer.ParamErr("从机无法向主机发送回调请求,请检查主机端 参数设置 - 站点信息 - 站点URL设置,并确保从机可以连接到此地址,"+err.Error(), nil) } if res.Data.(string) != conf.BackendVersion { return serializer.ParamErr("Cloudreve版本不一致,主机:"+res.Data.(string)+",从机:"+conf.BackendVersion, nil) } return serializer.Response{} } // Test 测试从机通信 func (service *SlaveTestService) Test() serializer.Response { slave, err := url.Parse(service.Server) if err != nil { return serializer.ParamErr("无法解析从机端地址,"+err.Error(), nil) } controller, _ := url.Parse("/api/v3/slave/ping") // 请求正文 body := map[string]string{ "callback": model.GetSiteURL().String(), } bodyByte, _ := json.Marshal(body) r := request.HTTPClient{} res, err := r.Request( "POST", slave.ResolveReference(controller).String(), bytes.NewReader(bodyByte), request.WithTimeout(time.Duration(10)*time.Second), request.WithCredential( auth.HMACAuth{SecretKey: []byte(service.Secret)}, int64(model.GetIntSetting("slave_api_timeout", 60)), ), ).DecodeResponse() if err != nil { return serializer.ParamErr("无连接到从机,"+err.Error(), nil) } if res.Code != 0 { return serializer.ParamErr("成功接到从机,但是"+res.Msg, nil) } return serializer.Response{} } // Add 添加存储策略 func (service *AddPolicyService) Add() serializer.Response { if err := model.DB.Create(&service.Policy).Error; err != nil { return serializer.ParamErr("存储策略添加失败", err) } return serializer.Response{Data: service.Policy.ID} } // Test 测试本地路径 func (service *PathTestService) Test() serializer.Response { policy := model.Policy{DirNameRule: service.Path} path := policy.GeneratePath(1, "/My File") path = filepath.Join(path, "test.txt") file, err := util.CreatNestedFile(path) if err != nil { return serializer.ParamErr(fmt.Sprintf("无法创建路径 %s , %s", path, err.Error()), nil) } file.Close() os.Remove(path) return serializer.Response{} } // Policies 列出存储策略 func (service *AdminListService) Policies() serializer.Response { var res []model.Policy total := 0 tx := model.DB.Model(&model.Policy{}) if service.OrderBy != "" { tx = tx.Order(service.OrderBy) } for k, v := range service.Conditions { tx = tx.Where(k+" = ?", v) } // 计算总数用于分页 tx.Count(&total) // 查询记录 tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) // 统计每个策略的文件使用 statics := make(map[uint][2]int, len(res)) for i := 0; i < len(res); i++ { total := [2]int{} row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID). Select("count(id),sum(size)").Row() row.Scan(&total[0], &total[1]) statics[res[i].ID] = total } return serializer.Response{Data: map[string]interface{}{ "total": total, "items": res, "statics": statics, }} }