diff --git a/go.mod b/go.mod index e7233f5b..64df0eb5 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gin-gonic/gin v1.5.0 github.com/go-ini/ini v1.50.0 github.com/go-mail/mail v2.3.1+incompatible + github.com/gofrs/uuid v4.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 github.com/gorilla/websocket v1.4.1 diff --git a/go.sum b/go.sum index 280b0f86..87280644 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 7398fa8e..90b40389 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -284,7 +284,26 @@ func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interf } func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) { - panic("implement me") + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + } + + res, err := s.SendAria2Call(req, "status") + if err != nil { + return rpc.StatusInfo{}, err + } + + if res.Code != 0 { + return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res) + } + + var status rpc.StatusInfo + res.GobDecode(&status) + + return status, err } func (s *slaveCaller) Cancel(task *model.Download) error { diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 37e70c69..d5e971c1 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -5,14 +5,6 @@ import ( "github.com/gin-gonic/gin" ) -// Response 基础序列化器 -type Response struct { - Code int `json:"code"` - Data interface{} `json:"data,omitempty"` - Msg string `json:"msg"` - Error string `json:"error,omitempty"` -} - // AppError 应用错误,实现了error接口 type AppError struct { Code int diff --git a/pkg/serializer/response.go b/pkg/serializer/response.go new file mode 100644 index 00000000..91aae47f --- /dev/null +++ b/pkg/serializer/response.go @@ -0,0 +1,35 @@ +package serializer + +import ( + "bytes" + "encoding/base64" + "encoding/gob" +) + +// Response 基础序列化器 +type Response struct { + Code int `json:"code"` + Data interface{} `json:"data,omitempty"` + Msg string `json:"msg"` + Error string `json:"error,omitempty"` +} + +// NewResponseWithGobData 返回Data字段使用gob编码的Response +func NewResponseWithGobData(data interface{}) Response { + var w bytes.Buffer + encoder := gob.NewEncoder(&w) + if err := encoder.Encode(data); err != nil { + return Err(CodeInternalSetting, "无法编码返回结果", err) + } + + return Response{Data: w.Bytes()} +} + +// GobDecode 将 Response 正文解码至目标指针 +func (r *Response) GobDecode(target interface{}) { + src := r.Data.(string) + raw := make([]byte, len(src)*len(src)/base64.StdEncoding.DecodedLen(len(src))) + base64.StdEncoding.Decode(raw, []byte(src)) + decoder := gob.NewDecoder(bytes.NewBuffer(raw)) + decoder.Decode(target) +} diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index 55296663..c74a3dc5 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -1,8 +1,10 @@ package slave import ( + "encoding/gob" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -38,6 +40,7 @@ func Init() { DefaultController = &slaveController{ masters: make(map[string]masterInfo), } + gob.Register(rpc.StatusInfo{}) } func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) { diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 1c282244..7926a016 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -199,3 +199,14 @@ func SlaveAria2Create(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveAria2Status 查询 Aria2 任务状态 +func SlaveAria2Status(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.Status(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index 2ec2f049..e152406c 100644 --- a/routers/router.go +++ b/routers/router.go @@ -57,7 +57,10 @@ func InitSlaveRouter() *gin.Engine { // 离线下载 aria2 := v3.Group("aria2") { + // 创建离线下载任务 aria2.POST("task", controllers.SlaveAria2Create) + // 创建离线下载任务 + aria2.POST("status", controllers.SlaveAria2Status) } } return r diff --git a/service/aria2/manage.go b/service/aria2/manage.go index b606c71a..eb12a7a9 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -5,6 +5,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" ) @@ -91,3 +92,24 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response { return serializer.Response{} } + +// Status 从机查询离线任务状态 +func Status(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + if siteID, exist := c.Get("MasterSiteID"); exist { + // 获取对应主机节点的从机Aria2实例 + caller, err := slave.DefaultController.GetAria2Instance(siteID.(string)) + if err != nil { + return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err) + } + + // 查询任务 + status, err := caller.Status(service.Task) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err) + } + + return serializer.NewResponseWithGobData(status) + } + + return serializer.ParamErr("未知的主机节点ID", nil) +}