Merge remote-tracking branch 'origin/master'

pull/1491/head
HFO4 2 years ago
commit f2c53dda31

@ -98,16 +98,7 @@ func Init(path string, statics fs.FS) {
}
for _, dependency := range dependencies {
switch dependency.mode {
case "master":
if conf.SystemConfig.Mode == "master" {
dependency.factory()
}
case "slave":
if conf.SystemConfig.Mode == "slave" {
dependency.factory()
}
default:
if dependency.mode == conf.SystemConfig.Mode || dependency.mode == "both" {
dependency.factory()
}
}

@ -38,7 +38,7 @@ require (
github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac
github.com/upyun/go-sdk v2.1.0+incompatible
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
gopkg.in/go-playground/validator.v9 v9.29.1
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
)
require (
@ -99,7 +99,6 @@ require (
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.12 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/mattn/go-sqlite3 v1.14.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@ -149,7 +148,6 @@ require (
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20211020174200-9d6173849985 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.org/x/tools v0.1.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/appengine v1.6.7 // indirect

@ -1393,10 +1393,8 @@ gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qS
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o=
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y=
gopkg.in/go-playground/validator.v9 v9.29.1 h1:SvGtYmN60a5CVKTOzMSyfzWDeZRxRuGvRQyEAKbw1xc=
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=

@ -1,14 +1,21 @@
package main
import (
"context"
_ "embed"
"flag"
"io"
"io/fs"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/cloudreve/Cloudreve/v3/bootstrap"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/routers"
@ -41,6 +48,9 @@ func init() {
}
func main() {
// 关闭数据库连接
defer model.DB.Close()
if isEject {
// 开始导出内置静态资源文件
bootstrap.Eject(staticFS)
@ -54,16 +64,35 @@ func main() {
}
api := routers.InitRouter()
server := &http.Server{Handler: api}
// 收到信号后关闭服务器
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
sig := <-sigChan
util.Log().Info("收到信号 %s开始关闭 server", sig)
ctx := context.Background()
if conf.SystemConfig.GracePeriod != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
defer cancel()
}
err := server.Shutdown(ctx)
if err != nil {
util.Log().Error("关闭 server 错误, %s", err)
}
}()
// 如果启用了SSL
if conf.SSLConfig.CertPath != "" {
go func() {
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen)
if err := api.RunTLS(conf.SSLConfig.Listen,
conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SSLConfig.Listen, err)
}
}()
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen)
server.Addr = conf.SSLConfig.Listen
if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SSLConfig.Listen, err)
return
}
}
// 如果启用了Unix
@ -78,14 +107,26 @@ func main() {
api.TrustedPlatform = conf.UnixConfig.ProxyHeader
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
if err := RunUnix(server); err != nil {
util.Log().Error("无法监听[%s]%s", conf.UnixConfig.Listen, err)
}
return
}
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
if err := api.Run(conf.SystemConfig.Listen); err != nil {
server.Addr = conf.SystemConfig.Listen
if err := server.ListenAndServe(); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SystemConfig.Listen, err)
}
}
func RunUnix(server *http.Server) error {
listener, err := net.Listen("unix", conf.UnixConfig.Listen)
if err != nil {
return err
}
defer listener.Close()
defer os.Remove(conf.UnixConfig.Listen)
return server.Serve(listener)
}

@ -186,6 +186,10 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) {
// 结果值
filteredFiles := make([]File, 0)
if len(files) == 0 {
return filteredFiles, nil
}
// 查询软链接的文件
var filesWithSoftLinks []File
tx := DB

@ -257,6 +257,19 @@ func TestFile_GetPolicy(t *testing.T) {
}
}
func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) {
asserts := assert.New(t)
// 传入空
{
mock.ExpectQuery("SELECT(.+)files(.+)")
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.Error(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(len(file), 0)
DB.Find(&File{})
}
}
func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts := assert.New(t)
files := []File{

@ -3,8 +3,6 @@ package aria2
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url"
"sync"
"time"
@ -14,6 +12,8 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
)
// Instance 默认使用的Aria2处理实例
@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
for i := 0; i < len(unfinished); i++ {
// 创建任务监控

@ -46,6 +46,8 @@ const (
Canceled
// Unknown 未知状态
Unknown
// Seeding 做种中
Seeding
)
var (
@ -94,11 +96,14 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
}
// GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status string) int {
switch status {
func GetStatus(status rpc.StatusInfo) int {
switch status.Status {
case "complete":
return Complete
case "active":
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
return Seeding
}
return Downloading
case "waiting":
return Ready

@ -1,9 +1,11 @@
package common
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDummyAria2(t *testing.T) {
@ -35,11 +37,18 @@ func TestDummyAria2(t *testing.T) {
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus("complete"), Complete)
a.Equal(GetStatus("active"), Downloading)
a.Equal(GetStatus("waiting"), Ready)
a.Equal(GetStatus("paused"), Paused)
a.Equal(GetStatus("error"), Error)
a.Equal(GetStatus("removed"), Canceled)
a.Equal(GetStatus("unknown"), Unknown)
a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
TotalLength: "100", CompletedLength: "50"}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
TotalLength: "100", CompletedLength: "100"}), Seeding)
a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
}

@ -110,14 +110,14 @@ func (monitor *Monitor) Update() bool {
util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status)
switch status.Status {
case "complete":
switch common.GetStatus(status) {
case common.Complete, common.Seeding:
return monitor.Complete(task.TaskPoll)
case "error":
case common.Error:
return monitor.Error(status)
case "active", "waiting", "paused":
case common.Downloading, common.Ready, common.Paused:
return false
case "removed":
case common.Canceled:
monitor.Task.Status = common.Canceled
monitor.Task.Save()
monitor.RemoveTempFolder()
@ -133,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid
monitor.Task.Status = common.GetStatus(status.Status)
monitor.Task.Status = common.GetStatus(status)
// 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@ -236,6 +236,40 @@ func (monitor *Monitor) RemoveTempFolder() {
// Complete 完成下载,返回是否中断监控
func (monitor *Monitor) Complete(pool task.Pool) bool {
// 未开始转存,提交转存任务
if monitor.Task.TaskID == 0 {
return monitor.transfer(pool)
}
// 做种完成
if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 转存完成,回收下载目录
if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
job, err := task.NewRecycleTask(monitor.Task)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 提交回收任务
pool.Submit(job)
return true
}
}
return false
}
func (monitor *Monitor) transfer(pool task.Pool) bool {
// 创建中转任务
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
@ -270,7 +304,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool {
monitor.Task.TaskID = job.Model().ID
monitor.Task.Save()
return true
return false
}
func (monitor *Monitor) setErrorStatus(err error) {

@ -3,6 +3,8 @@ package monitor
import (
"database/sql"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
@ -13,7 +15,6 @@ import (
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
)
var mock sqlmock.Sqlmock
@ -431,6 +432,14 @@ func TestMonitor_Complete(t *testing.T) {
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
a.False(m.Complete(mockPool))
m.Task.StatusInfo.Status = "complete"
a.True(m.Complete(mockPool))
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)

@ -4,35 +4,27 @@ package rpc
// StatusInfo represents response of aria2.tellStatus
type StatusInfo struct {
Gid string `json:"gid"` // GID of the download.
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec.
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
PieceLength string `json:"pieceLength"` // Piece length in bytes.
NumPieces string `json:"numPieces"` // The number of pieces.
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
} `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
Gid string `json:"gid"` // GID of the download.
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec.
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
PieceLength string `json:"pieceLength"` // Piece length in bytes.
NumPieces string `json:"numPieces"` // The number of pieces.
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
}
// URIInfo represents an element of response of aria2.getUris
@ -100,3 +92,13 @@ type Method struct {
Name string `json:"methodName"` // Method name to call
Params []interface{} `json:"params"` // Array containing parameters to the method call
}
type BitTorrentInfo struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
}

@ -3,7 +3,7 @@ package conf
import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-ini/ini"
"gopkg.in/go-playground/validator.v9"
"github.com/go-playground/validator/v10"
)
// database 数据库
@ -26,6 +26,7 @@ type system struct {
Debug bool
SessionSecret string
HashIDSalt string
GracePeriod int `validate:"gte=0"`
}
type ssl struct {

@ -5,6 +5,9 @@ import (
"context"
"encoding/json"
"errors"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
@ -13,8 +16,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/url"
"time"
)
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
@ -118,6 +119,6 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

@ -4,6 +4,7 @@ import (
"crypto/sha1"
"encoding/gob"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
)

@ -1,9 +1,10 @@
package serializer
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestSlaveTransferReq_Hash(t *testing.T) {
@ -18,3 +19,14 @@ func TestSlaveTransferReq_Hash(t *testing.T) {
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}
func TestSlaveRecycleReq_Hash(t *testing.T) {
a := assert.New(t)
s1 := &SlaveRecycleReq{
Path: "1",
}
s2 := &SlaveRecycleReq{
Path: "2",
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}

@ -15,6 +15,8 @@ const (
TransferTaskType
// ImportTaskType 导入任务
ImportTaskType
// RecycleTaskType 回收任务
RecycleTaskType
)
// 任务状态
@ -113,6 +115,8 @@ func GetJobFromModel(task *model.Task) (Job, error) {
return NewTransferTaskFromModel(task)
case ImportTaskType:
return NewImportTaskFromModel(task)
case RecycleTaskType:
return NewRecycleTaskFromModel(task)
default:
return nil, ErrUnknownTaskType
}

@ -2,12 +2,12 @@ package task
import (
"errors"
testMock "github.com/stretchr/testify/mock"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestRecord(t *testing.T) {
@ -103,4 +103,16 @@ func TestGetJobFromModel(t *testing.T) {
asserts.Nil(job)
asserts.Error(err)
}
// RecycleTaskType
{
task := &model.Task{
Status: 0,
Type: RecycleTaskType,
}
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
job, err := GetJobFromModel(task)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Nil(job)
asserts.Error(err)
}
}

@ -0,0 +1,130 @@
package task
import (
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// RecycleTask 文件回收任务
type RecycleTask struct {
User *model.User
TaskModel *model.Task
TaskProps RecycleProps
Err *JobError
}
// RecycleProps 回收任务属性
type RecycleProps struct {
// 下载任务 GID
DownloadGID string `json:"download_gid"`
}
// Props 获取任务属性
func (job *RecycleTask) Props() string {
res, _ := json.Marshal(job.TaskProps)
return string(res)
}
// Type 获取任务状态
func (job *RecycleTask) Type() int {
return RecycleTaskType
}
// Creator 获取创建者ID
func (job *RecycleTask) Creator() uint {
return job.User.ID
}
// Model 获取任务的数据库模型
func (job *RecycleTask) Model() *model.Task {
return job.TaskModel
}
// SetStatus 设定状态
func (job *RecycleTask) SetStatus(status int) {
job.TaskModel.SetStatus(status)
}
// SetError 设定任务失败信息
func (job *RecycleTask) SetError(err *JobError) {
job.Err = err
res, _ := json.Marshal(job.Err)
job.TaskModel.SetError(string(res))
}
// SetErrorMsg 设定任务失败信息
func (job *RecycleTask) SetErrorMsg(msg string, err error) {
jobErr := &JobError{Msg: msg}
if err != nil {
jobErr.Error = err.Error()
}
job.SetError(jobErr)
}
// GetError 返回任务失败信息
func (job *RecycleTask) GetError() *JobError {
return job.Err
}
// Do 开始执行任务
func (job *RecycleTask) Do() {
download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID)
if err != nil {
util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID)
job.SetErrorMsg("无法找到下载任务", err)
return
}
nodeID := download.GetNodeID()
node := cluster.Default.GetNodeByID(nodeID)
if node == nil {
util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID)
job.SetErrorMsg("从机节点不可用", nil)
return
}
err = node.GetAria2Instance().DeleteTempFile(download)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err)
job.SetErrorMsg("文件回收失败", err)
return
}
}
// NewRecycleTask 新建回收任务
func NewRecycleTask(download *model.Download) (Job, error) {
newTask := &RecycleTask{
User: download.GetOwner(),
TaskProps: RecycleProps{
DownloadGID: download.GID,
},
}
record, err := Record(newTask)
if err != nil {
return nil, err
}
newTask.TaskModel = record
return newTask, nil
}
// NewRecycleTaskFromModel 从数据库记录中恢复回收任务
func NewRecycleTaskFromModel(task *model.Task) (Job, error) {
user, err := model.GetActiveUserByID(task.UserID)
if err != nil {
return nil, err
}
newTask := &RecycleTask{
User: &user,
TaskModel: task,
}
err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps)
if err != nil {
return nil, err
}
return newTask, nil
}

@ -0,0 +1,117 @@
package task
import (
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestRecycleTask_Props(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
}
asserts.NotEmpty(task.Props())
asserts.Equal(RecycleTaskType, task.Type())
asserts.EqualValues(0, task.Creator())
asserts.Nil(task.Model())
}
func TestRecycleTask_SetStatus(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
TaskModel: &model.Task{
Model: gorm.Model{ID: 1},
},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task.SetStatus(3)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestRecycleTask_SetError(t *testing.T) {
asserts := assert.New(t)
task := &RecycleTask{
User: &model.User{},
TaskModel: &model.Task{
Model: gorm.Model{ID: 1},
},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task.SetErrorMsg("error", nil)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal("error", task.GetError().Msg)
}
func TestNewRecycleTask(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
job, err := NewRecycleTask(&model.Download{
Model: gorm.Model{ID: 1},
GID: "test_g_id",
Parent: "/",
UserID: 1,
NodeID: 1,
})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotNil(job)
asserts.NoError(err)
}
// 失败
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
job, err := NewRecycleTask(&model.Download{
Model: gorm.Model{ID: 1},
GID: "test_g_id",
Parent: "test/not_exist",
UserID: 1,
NodeID: 1,
})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Nil(job)
asserts.Error(err)
}
}
func TestNewRecycleTaskFromModel(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.NotNil(job)
}
// JSON解析失败
{
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Nil(job)
}
}

@ -2,6 +2,8 @@ package slavetask
import (
"context"
"os"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
@ -10,7 +12,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"os"
)
// TransferTask 文件中转任务
@ -79,8 +80,6 @@ func (job *TransferTask) GetError() *task.JobError {
// Do 开始执行任务
func (job *TransferTask) Do() {
defer job.Recycle()
fs, err := filesystem.NewAnonymousFileSystem()
if err != nil {
job.SetErrorMsg("无法初始化匿名文件系统", err)
@ -137,11 +136,3 @@ func (job *TransferTask) Do() {
util.Log().Warning("无法发送转存成功通知到从机, %s", err)
}
}
// Recycle 回收临时文件
func (job *TransferTask) Recycle() {
err := os.Remove(job.Req.Src)
if err != nil {
util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Src, err)
}
}

@ -3,7 +3,6 @@ package task
import (
"context"
"encoding/json"
"os"
"path"
"path/filepath"
"strings"
@ -87,8 +86,6 @@ func (job *TransferTask) GetError() *JobError {
// Do 开始执行任务
func (job *TransferTask) Do() {
defer job.Recycle()
// 创建文件系统
fs, err := filesystem.NewFileSystem(job.User)
if err != nil {
@ -139,16 +136,6 @@ func (job *TransferTask) Do() {
}
// Recycle 回收临时文件
func (job *TransferTask) Recycle() {
if job.TaskProps.NodeID == 1 {
err := os.RemoveAll(job.TaskProps.Parent)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err)
}
}
}
// NewTransferTask 新建中转任务
func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) {
creator, err := model.GetActiveUserByID(user)

@ -33,7 +33,7 @@ func (service *DownloadListService) Finished(c *gin.Context, user *model.User) s
// Downloading 获取正在下载中的任务
func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready)
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready)
intervals := make(map[uint]int)
for _, download := range downloads {
if _, ok := intervals[download.ID]; !ok {
@ -57,7 +57,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
return serializer.Err(serializer.CodeNotFound, "Download record not found", err)
}
if download.Status >= common.Error {
if download.Status >= common.Error && download.Status <= common.Unknown {
// 如果任务已完成,则删除任务记录
if err := download.Delete(); err != nil {
return serializer.DBErr("Failed to delete task record", err)

@ -5,6 +5,10 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
@ -16,9 +20,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"net/http"
"net/url"
"time"
)
// SlaveDownloadService 从机文件下載服务

@ -1,14 +1,15 @@
package user
import (
"net/url"
"strings"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
"net/url"
"strings"
)
// UserRegisterService 管理用户注册的服务

Loading…
Cancel
Save