Feat: create and monitor aria2 task in master node

pull/1040/head
HFO4 4 years ago
parent 56590829d1
commit 97aaa35792

@ -24,6 +24,7 @@ type Download struct {
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
UserID uint // 发起者UID UserID uint // 发起者UID
TaskID uint // 对应的转存任务ID TaskID uint // 对应的转存任务ID
NodeID uint // 处理任务的节点ID
// 关联模型 // 关联模型
User *User `gorm:"PRELOAD:false,association_autoupdate:false"` User *User `gorm:"PRELOAD:false,association_autoupdate:false"`

@ -4,13 +4,13 @@ import (
"sync" "sync"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
) )
// Instance 默认使用的Aria2处理实例 // Instance 默认使用的Aria2处理实例
var Instance Aria2 = &DummyAria2{} var Instance common.Aria2 = &common.DummyAria2{}
// LB 获取 Aria2 节点的负载均衡器 // LB 获取 Aria2 节点的负载均衡器
var LB balancer.Balancer var LB balancer.Balancer
@ -18,82 +18,6 @@ var LB balancer.Balancer
// Lock Instance的读写锁 // Lock Instance的读写锁
var Lock sync.RWMutex var Lock sync.RWMutex
// EventNotifier 任务状态更新通知处理器
var EventNotifier = &Notifier{}
// Aria2 离线下载处理接口
type Aria2 interface {
// Init 初始化客户端连接
Init() error
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
func (instance *DummyAria2) Init() error {
return nil
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
return "", ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
}
// Init 初始化 // Init 初始化
func Init(isReload bool) { func Init(isReload bool) {
Lock.Lock() Lock.Lock()
@ -102,31 +26,11 @@ func Init(isReload bool) {
if !isReload { if !isReload {
// 从数据库中读取未完成任务,创建监控 // 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading) unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
for i := 0; i < len(unfinished); i++ { for i := 0; i < len(unfinished); i++ {
// 创建任务监控 // 创建任务监控
NewMonitor(&unfinished[i]) monitor.NewMonitor(&unfinished[i])
} }
} }
} }
// getStatus 将给定的状态字符串转换为状态标识数字
func getStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
}
}

@ -6,6 +6,7 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -37,7 +38,7 @@ func TestDummyAria2(t *testing.T) {
} }
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
MAX_RETRY = 0 monitor.MAX_RETRY = 0
asserts := assert.New(t) asserts := assert.New(t)
cache.Set("setting_aria2_token", "1", 0) cache.Set("setting_aria2_token", "1", 0)
cache.Set("setting_aria2_call_timeout", "5", 0) cache.Set("setting_aria2_call_timeout", "5", 0)
@ -81,11 +82,11 @@ func TestInit(t *testing.T) {
func TestGetStatus(t *testing.T) { func TestGetStatus(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
asserts.Equal(4, getStatus("complete")) asserts.Equal(4, GetStatus("complete"))
asserts.Equal(1, getStatus("active")) asserts.Equal(1, GetStatus("active"))
asserts.Equal(0, getStatus("waiting")) asserts.Equal(0, GetStatus("waiting"))
asserts.Equal(2, getStatus("paused")) asserts.Equal(2, GetStatus("paused"))
asserts.Equal(3, getStatus("error")) asserts.Equal(3, GetStatus("error"))
asserts.Equal(5, getStatus("removed")) asserts.Equal(5, GetStatus("removed"))
asserts.Equal(6, getStatus("?")) asserts.Equal(6, GetStatus("?"))
} }

@ -8,6 +8,7 @@ import (
"time" "time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
) )
@ -33,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s
Options: options, Options: options,
} }
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
EventNotifier) common.EventNotifier)
client.Caller = caller client.Caller = caller
return err return err
} }

@ -0,0 +1,110 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Aria2 离线下载处理接口
type Aria2 interface {
// Init 初始化客户端连接
Init() error
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
// GetConfig 获取离线下载配置
GetConfig() model.Aria2Option
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
func (instance *DummyAria2) Init() error {
return nil
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
return "", ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
}
// GetConfig 返回空的
func (instance *DummyAria2) GetConfig() model.Aria2Option {
return model.Aria2Option{}
}
// GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
}
}
// EventNotifier 任务状态更新通知处理器
var EventNotifier = &Notifier{}

@ -1,4 +1,4 @@
package aria2 package common
import ( import (
"sync" "sync"
@ -6,7 +6,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
) )
// Notifier aria2实践通知处理 // Notifier aria2事件通知处理
type Notifier struct { type Notifier struct {
Subscribes sync.Map Subscribes sync.Map
} }
@ -62,3 +62,9 @@ func (notifier *Notifier) OnDownloadError(events []rpc.Event) {
func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) {
notifier.Notify(events, Complete) notifier.Notify(events, Complete)
} }
// StatusEvent 状态改变事件
type StatusEvent struct {
GID string
Status int
}

@ -1,4 +1,4 @@
package aria2 package monitor
import ( import (
"context" "context"
@ -10,7 +10,9 @@ import (
"time" "time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
@ -23,32 +25,31 @@ type Monitor struct {
Task *model.Download Task *model.Download
Interval time.Duration Interval time.Duration
notifier chan StatusEvent notifier chan common.StatusEvent
node cluster.Node
retried int retried int
} }
// StatusEvent 状态改变事件
type StatusEvent struct {
GID string
Status int
}
var MAX_RETRY = 10 var MAX_RETRY = 10
// NewMonitor 新建上传状态监控 // NewMonitor 新建离线下载状态监控
func NewMonitor(task *model.Download) { func NewMonitor(task *model.Download) {
monitor := &Monitor{ monitor := &Monitor{
Task: task, Task: task,
Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second, notifier: make(chan common.StatusEvent),
notifier: make(chan StatusEvent), node: cluster.Default.GetNodeByID(task.NodeID),
}
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop()
common.EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
} }
go monitor.Loop()
EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
} }
// Loop 开启监控循环 // Loop 开启监控循环
func (monitor *Monitor) Loop() { func (monitor *Monitor) Loop() {
defer EventNotifier.Unsubscribe(monitor.Task.GID) defer common.EventNotifier.Unsubscribe(monitor.Task.GID)
fmt.Println(cluster.Default)
// 首次循环立即更新 // 首次循环立即更新
interval := time.Duration(0) interval := time.Duration(0)
@ -70,9 +71,7 @@ func (monitor *Monitor) Loop() {
// Update 更新状态,返回值表示是否退出监控 // Update 更新状态,返回值表示是否退出监控
func (monitor *Monitor) Update() bool { func (monitor *Monitor) Update() bool {
Lock.RLock() status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
status, err := Instance.Status(monitor.Task)
Lock.RUnlock()
if err != nil { if err != nil {
monitor.retried++ monitor.retried++
@ -115,7 +114,7 @@ func (monitor *Monitor) Update() bool {
case "active", "waiting", "paused": case "active", "waiting", "paused":
return false return false
case "removed": case "removed":
monitor.Task.Status = Canceled monitor.Task.Status = common.Canceled
monitor.Task.Save() monitor.Task.Save()
monitor.RemoveTempFolder() monitor.RemoveTempFolder()
return true return true
@ -130,7 +129,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid monitor.Task.GID = status.Gid
monitor.Task.Status = getStatus(status.Status) monitor.Task.Status = common.GetStatus(status.Status)
// 文件大小、已下载大小 // 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64) total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@ -164,9 +163,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
// 文件大小更新后,对文件限制等进行校验 // 文件大小更新后,对文件限制等进行校验
if err := monitor.ValidateFile(); err != nil { if err := monitor.ValidateFile(); err != nil {
// 验证失败时取消任务 // 验证失败时取消任务
Lock.RLock() monitor.node.GetAria2Instance().Cancel(monitor.Task)
Instance.Cancel(monitor.Task)
Lock.RUnlock()
return err return err
} }
} }
@ -179,7 +176,7 @@ func (monitor *Monitor) ValidateFile() error {
// 找到任务创建者 // 找到任务创建者
user := monitor.Task.GetOwner() user := monitor.Task.GetOwner()
if user == nil { if user == nil {
return ErrUserNotFound return common.ErrUserNotFound
} }
// 创建文件系统 // 创建文件系统
@ -269,7 +266,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
} }
func (monitor *Monitor) setErrorStatus(err error) { func (monitor *Monitor) setErrorStatus(err error) {
monitor.Task.Status = Error monitor.Task.Status = common.Error
monitor.Task.Error = err.Error() monitor.Task.Error = err.Error()
monitor.Task.Save() monitor.Task.Save()
} }

@ -1,4 +1,4 @@
package aria2 package monitor
import ( import (
"errors" "errors"
@ -7,6 +7,8 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
@ -44,13 +46,13 @@ func (m InstanceMock) Select(task *model.Download, files []int) error {
func TestNewMonitor(t *testing.T) { func TestNewMonitor(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
NewMonitor(&model.Download{GID: "gid"}) NewMonitor(&model.Download{GID: "gid"})
_, ok := EventNotifier.Subscribes.Load("gid") _, ok := common.EventNotifier.Subscribes.Load("gid")
asserts.True(ok) asserts.True(ok)
} }
func TestMonitor_Loop(t *testing.T) { func TestMonitor_Loop(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
notifier := make(chan StatusEvent) notifier := make(chan common.StatusEvent)
MAX_RETRY = 0 MAX_RETRY = 0
monitor := &Monitor{ monitor := &Monitor{
Task: &model.Download{GID: "gid"}, Task: &model.Download{GID: "gid"},
@ -79,7 +81,7 @@ func TestMonitor_Update(t *testing.T) {
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error")) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1") file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close() file.Close()
Instance = testInstance aria2.Instance = testInstance
asserts.False(monitor.Update()) asserts.False(monitor.Update())
asserts.True(monitor.Update()) asserts.True(monitor.Update())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
@ -93,12 +95,12 @@ func TestMonitor_Update(t *testing.T) {
FollowedBy: []string{"1"}, FollowedBy: []string{"1"},
}, nil) }, nil)
monitor.Task.ID = 1 monitor.Task.ID = 1
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.False(monitor.Update()) asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
asserts.EqualValues("1", monitor.Task.GID) asserts.EqualValues("1", monitor.Task.GID)
} }
@ -108,12 +110,12 @@ func TestMonitor_Update(t *testing.T) {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.ID = 1 monitor.Task.ID = 1
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback() aria2.mock.ExpectRollback()
Instance = testInstance aria2.Instance = testInstance
asserts.True(monitor.Update()) asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -121,12 +123,12 @@ func TestMonitor_Update(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.True(monitor.Update()) asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -134,15 +136,15 @@ func TestMonitor_Update(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.True(monitor.Update()) asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -150,12 +152,12 @@ func TestMonitor_Update(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.False(monitor.Update()) asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -163,12 +165,12 @@ func TestMonitor_Update(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.True(monitor.Update()) asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -176,12 +178,12 @@ func TestMonitor_Update(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil) testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
Instance = testInstance aria2.Instance = testInstance
asserts.True(monitor.Update()) asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
} }
@ -198,21 +200,21 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
// 失败 // 失败
{ {
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback() aria2.mock.ExpectRollback()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err) asserts.Error(err)
} }
// 更新成功,无需校验 // 更新成功,无需校验
{ {
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{}) err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.NoError(err) asserts.NoError(err)
} }
@ -220,12 +222,12 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
{ {
testInstance := new(InstanceMock) testInstance := new(InstanceMock)
testInstance.On("Cancel", testMock.Anything).Return(nil) testInstance.On("Cancel", testMock.Anything).Return(nil)
Instance = testInstance aria2.Instance = testInstance
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"}) err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err) asserts.Error(err)
testInstance.AssertExpectations(t) testInstance.AssertExpectations(t)
} }
@ -308,17 +310,17 @@ func TestMonitor_Complete(t *testing.T) {
} }
cache.Set("setting_max_worker_num", "1", 0) cache.Set("setting_max_worker_num", "1", 0)
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"})) aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init() task.Init()
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
mock.ExpectBegin() aria2.mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() aria2.mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{})) asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(aria2.mock.ExpectationsWereMet())
} }

@ -3,13 +3,15 @@ package cluster
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/url" "net/url"
"path/filepath"
"strconv"
"strings"
"sync" "sync"
"time" "time"
) )
@ -49,6 +51,13 @@ func (node *MasterNode) Init(nodeModel *model.Node) {
node.lock.RUnlock() node.lock.RUnlock()
} }
func (node *MasterNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
return &serializer.NodePingResp{}, nil return &serializer.NodePingResp{}, nil
} }
@ -76,15 +85,15 @@ func (node *MasterNode) IsActive() bool {
} }
// GetAria2Instance 获取主机Aria2实例 // GetAria2Instance 获取主机Aria2实例
func (node *MasterNode) GetAria2Instance() aria2.Aria2 { func (node *MasterNode) GetAria2Instance() common.Aria2 {
if !node.Model.Aria2Enabled { if !node.Model.Aria2Enabled {
return &aria2.DummyAria2{} return &common.DummyAria2{}
} }
node.lock.RLock() node.lock.RLock()
defer node.lock.RUnlock() defer node.lock.RUnlock()
if !node.aria2RPC.Initialized { if !node.aria2RPC.Initialized {
return &aria2.DummyAria2{} return &common.DummyAria2{}
} }
return &node.aria2RPC return &node.aria2RPC
@ -122,25 +131,76 @@ func (r *rpcService) Init() error {
Options: globalOptions, Options: globalOptions,
} }
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, aria2.EventNotifier) caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, common.EventNotifier)
r.Caller = caller r.Caller = caller
r.Initialized = true r.Initialized = true
return err return err
} }
func (r *rpcService) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
return "", fmt.Errorf("some error #%d", r.parent.Model.ID) r.parent.lock.RLock()
// 生成存储路径
path := filepath.Join(
r.parent.Model.Aria2OptionsSerialized.TempPath,
"aria2",
strconv.FormatInt(time.Now().UnixNano(), 10),
)
r.parent.lock.RUnlock()
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range r.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := r.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return "", err
}
return gid, nil
} }
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
panic("implement me") res, err := r.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s10秒钟后重试", err)
time.Sleep(time.Duration(10) * time.Second)
res, err = r.Caller.TellStatus(task.GID)
}
return res, err
} }
func (r *rpcService) Cancel(task *model.Download) error { func (r *rpcService) Cancel(task *model.Download) error {
panic("implement me") // 取消下载任务
_, err := r.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
return err
} }
func (r *rpcService) Select(task *model.Download, files []int) error { func (r *rpcService) Select(task *model.Download, files []int) error {
panic("implement me") var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
func (r *rpcService) GetConfig() model.Aria2Option {
r.parent.lock.RLock()
defer r.parent.lock.RUnlock()
return r.parent.Model.Aria2OptionsSerialized
} }

@ -2,17 +2,25 @@ package cluster
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
) )
type Node interface { type Node interface {
// Init a node from database model
Init(node *model.Node) Init(node *model.Node)
// Check if given feature is enabled
IsFeatureEnabled(feature string) bool IsFeatureEnabled(feature string) bool
// Subscribe node status change to a callback function
SubscribeStatusChange(callback func(isActive bool, id uint)) SubscribeStatusChange(callback func(isActive bool, id uint))
// Ping the node
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
// Returns if the node is active
IsActive() bool IsActive() bool
GetAria2Instance() aria2.Aria2 // Get instances for aria2 calls
GetAria2Instance() common.Aria2
// Returns unique id of this node
ID() uint
} }
func getNodeFromDBModel(node *model.Node) Node { func getNodeFromDBModel(node *model.Node) Node {

@ -14,7 +14,11 @@ var featureGroup = []string{"aria2"}
// Pool 节点池 // Pool 节点池
type Pool interface { type Pool interface {
// Returns active node selected by given feature and load balancer
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
} }
// NodePool 通用节点池 // NodePool 通用节点池
@ -53,6 +57,17 @@ func (pool *NodePool) buildIndexMap() {
pool.lock.Unlock() pool.lock.Unlock()
} }
func (pool *NodePool) GetNodeByID(id uint) Node {
pool.lock.RLock()
defer pool.lock.RUnlock()
if node, ok := pool.active[id]; ok {
return node
}
return pool.inactive[id]
}
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) { func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive) util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive)
pool.lock.Lock() pool.lock.Lock()

@ -4,7 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -185,6 +185,13 @@ loop:
} }
// GetAria2Instance 获取从机Aria2实例 // GetAria2Instance 获取从机Aria2实例
func (node *SlaveNode) GetAria2Instance() aria2.Aria2 { func (node *SlaveNode) GetAria2Instance() common.Aria2 {
return nil return nil
} }
func (node *SlaveNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}

@ -3,7 +3,7 @@ package controllers
import ( import (
"context" "context"
ariaCall "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/service/aria2" "github.com/cloudreve/Cloudreve/v3/service/aria2"
"github.com/cloudreve/Cloudreve/v3/service/explorer" "github.com/cloudreve/Cloudreve/v3/service/explorer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -13,7 +13,7 @@ import (
func AddAria2URL(c *gin.Context) { func AddAria2URL(c *gin.Context) {
var addService aria2.AddURLService var addService aria2.AddURLService
if err := c.ShouldBindJSON(&addService); err == nil { if err := c.ShouldBindJSON(&addService); err == nil {
res := addService.Add(c, ariaCall.URLTask) res := addService.Add(c, common.URLTask)
c.JSON(200, res) c.JSON(200, res)
} else { } else {
c.JSON(200, ErrorResponse(err)) c.JSON(200, ErrorResponse(err))
@ -52,7 +52,7 @@ func AddAria2Torrent(c *gin.Context) {
if err := c.ShouldBindJSON(&addService); err == nil { if err := c.ShouldBindJSON(&addService); err == nil {
addService.URL = res.Data.(string) addService.URL = res.Data.(string)
res := addService.Add(c, ariaCall.URLTask) res := addService.Add(c, common.URLTask)
c.JSON(200, res) c.JSON(200, res)
} else { } else {
c.JSON(200, ErrorResponse(err)) c.JSON(200, ErrorResponse(err))

@ -3,6 +3,8 @@ package aria2
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -36,7 +38,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
// 创建任务 // 创建任务
task := &model.Download{ task := &model.Download{
Status: aria2.Ready, Status: common.Ready,
Type: taskType, Type: taskType,
Dst: service.Dst, Dst: service.Dst,
UserID: fs.User.ID, UserID: fs.User.ID,
@ -61,14 +63,14 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
} }
task.GID = gid task.GID = gid
task.NodeID = node.ID()
_, err = task.Create() _, err = task.Create()
if err != nil { if err != nil {
return serializer.DBErr("任务创建失败", err) return serializer.DBErr("任务创建失败", err)
} }
// 创建任务监控 // 创建任务监控
aria2.NewMonitor(task) monitor.NewMonitor(task)
aria2.Lock.RUnlock()
return serializer.Response{} return serializer.Response{}
} }

@ -3,6 +3,7 @@ package aria2
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -25,14 +26,14 @@ type DownloadListService struct {
// Finished 获取已完成的任务 // Finished 获取已完成的任务
func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录 // 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Error, aria2.Complete, aria2.Canceled, aria2.Unknown) downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown)
return serializer.BuildFinishedListResponse(downloads) return serializer.BuildFinishedListResponse(downloads)
} }
// Downloading 获取正在下载中的任务 // Downloading 获取正在下载中的任务
func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录 // 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Downloading, aria2.Paused, aria2.Ready) downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready)
return serializer.BuildDownloadingResponse(downloads) return serializer.BuildDownloadingResponse(downloads)
} }
@ -47,7 +48,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
} }
if download.Status >= aria2.Error { if download.Status >= common.Error {
// 如果任务已完成,则删除任务记录 // 如果任务已完成,则删除任务记录
if err := download.Delete(); err != nil { if err := download.Delete(); err != nil {
return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err) return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err)
@ -76,7 +77,7 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
} }
if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) { if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != common.Downloading && download.Status != common.Paused) {
return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err) return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err)
} }

Loading…
Cancel
Save