Feat: task queue / compression task

pull/247/head
HFO4 5 years ago
parent b1490a665c
commit e722c33cd5

@ -6,6 +6,7 @@ import (
"github.com/HFO4/cloudreve/pkg/authn" "github.com/HFO4/cloudreve/pkg/authn"
"github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf" "github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/task"
"github.com/HFO4/cloudreve/routers" "github.com/HFO4/cloudreve/routers"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -20,6 +21,7 @@ func init() {
if conf.SystemConfig.Mode == "master" { if conf.SystemConfig.Mode == "master" {
model.Init() model.Init()
authn.Init() authn.Init()
task.Init()
} }
auth.Init() auth.Init()
} }

@ -136,7 +136,7 @@ func TestBeforeShareDownload(t *testing.T) {
c.Set("user", &model.User{ c.Set("user", &model.User{
Model: gorm.Model{ID: 1}, Model: gorm.Model{ID: 1},
Group: model.Group{OptionsSerialized: model.GroupOption{ Group: model.Group{OptionsSerialized: model.GroupOption{
ShareDownloadEnabled: true, ShareDownload: true,
}}, }},
}) })
testFunc(c) testFunc(c)

@ -16,7 +16,7 @@ type Group struct {
Aria2Option string Aria2Option string
Color string Color string
SpeedLimit int SpeedLimit int
Options string `json:"-",gorm:"size:4096"` Options string `json:"-",gorm:"type:text"`
// 数据库忽略字段 // 数据库忽略字段
PolicyList []uint `gorm:"-"` PolicyList []uint `gorm:"-"`
@ -25,11 +25,11 @@ type Group struct {
// GroupOption 用户组其他配置 // GroupOption 用户组其他配置
type GroupOption struct { type GroupOption struct {
ArchiveDownloadEnabled bool `json:"archive_download,omitempty"` ArchiveDownload bool `json:"archive_download,omitempty"`
ArchiveTaskEnabled bool `json:"archive_task,omitempty"` ArchiveTask bool `json:"archive_task,omitempty"`
OneTimeDownloadEnabled bool `json:"one_time_download,omitempty"` OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownloadEnabled bool `json:"share_download,omitempty"` ShareDownload bool `json:"share_download,omitempty"`
ShareFreeEnabled bool `json:"share_free,omitempty"` ShareFree bool `json:"share_free,omitempty"`
} }
// GetAria2Option 获取用户离线下载设备 // GetAria2Option 获取用户离线下载设备

@ -47,7 +47,7 @@ func Init() {
// Debug模式下输出所有 SQL 日志 // Debug模式下输出所有 SQL 日志
if conf.SystemConfig.Debug { if conf.SystemConfig.Debug {
db.LogMode(true) db.LogMode(false)
} }
//db.SetLogger(util.Log()) //db.SetLogger(util.Log())

@ -29,7 +29,7 @@ func migration() {
if conf.DatabaseConfig.Type == "mysql" { if conf.DatabaseConfig.Type == "mysql" {
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
} }
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}) DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}, &Task{})
// 创建初始存储策略 // 创建初始存储策略
addDefaultPolicy() addDefaultPolicy()
@ -159,7 +159,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "aria2_token", Value: `your token`, Type: "aria2"}, {Name: "aria2_token", Value: `your token`, Type: "aria2"},
{Name: "aria2_rpcurl", Value: `http://127.0.0.1:6800/`, Type: "aria2"}, {Name: "aria2_rpcurl", Value: `http://127.0.0.1:6800/`, Type: "aria2"},
{Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"}, {Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"},
{Name: "task_queue_token", Value: ``, Type: "task"}, {Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
{Name: "temp_path", Value: "temp", Type: "path"}, {Name: "temp_path", Value: "temp", Type: "path"},
{Name: "score_enabled", Value: "1", Type: "score"}, {Name: "score_enabled", Value: "1", Type: "score"},
@ -186,9 +186,9 @@ func addDefaultGroups() {
WebDAVEnabled: true, WebDAVEnabled: true,
Aria2Option: "0,0,0", Aria2Option: "0,0,0",
OptionsSerialized: GroupOption{ OptionsSerialized: GroupOption{
ArchiveDownloadEnabled: true, ArchiveDownload: true,
ArchiveTaskEnabled: true, ArchiveTask: true,
ShareDownloadEnabled: true, ShareDownload: true,
}, },
} }
if err := DB.Create(&defaultAdminGroup).Error; err != nil { if err := DB.Create(&defaultAdminGroup).Error; err != nil {

@ -22,14 +22,14 @@ type Policy struct {
BucketName string BucketName string
IsPrivate bool IsPrivate bool
BaseURL string BaseURL string
AccessKey string `gorm:"size:1024"` AccessKey string `gorm:"type:text"`
SecretKey string `gorm:"size:1024"` SecretKey string `gorm:"type:text"`
MaxSize uint64 MaxSize uint64
AutoRename bool AutoRename bool
DirNameRule string DirNameRule string
FileNameRule string FileNameRule string
IsOriginLinkEnable bool IsOriginLinkEnable bool
Options string `gorm:"size:4096"` Options string `gorm:"type:text"`
// 数据库忽略字段 // 数据库忽略字段
OptionsSerialized PolicyOption `gorm:"-"` OptionsSerialized PolicyOption `gorm:"-"`

@ -123,7 +123,7 @@ func (share *Share) SourceFile() *File {
// CanBeDownloadBy 返回此分享是否可以被给定用户下载 // CanBeDownloadBy 返回此分享是否可以被给定用户下载
func (share *Share) CanBeDownloadBy(user *User) error { func (share *Share) CanBeDownloadBy(user *User) error {
// 用户组权限 // 用户组权限
if !user.Group.OptionsSerialized.ShareDownloadEnabled { if !user.Group.OptionsSerialized.ShareDownload {
if user.IsAnonymous() { if user.IsAnonymous() {
return errors.New("未登录用户无法下载") return errors.New("未登录用户无法下载")
} }
@ -169,7 +169,7 @@ func (share *Share) DownloadBy(user *User, c *gin.Context) error {
// Purchase 使用积分购买分享 // Purchase 使用积分购买分享
func (share *Share) Purchase(user *User) error { func (share *Share) Purchase(user *User) error {
// 不需要付积分 // 不需要付积分
if share.Score == 0 || user.Group.OptionsSerialized.ShareFreeEnabled || user.ID == share.UserID { if share.Score == 0 || user.Group.OptionsSerialized.ShareFree || user.ID == share.UserID {
return nil return nil
} }

@ -156,7 +156,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
user := &User{ user := &User{
Group: Group{ Group: Group{
OptionsSerialized: GroupOption{ OptionsSerialized: GroupOption{
ShareDownloadEnabled: false, ShareDownload: false,
}, },
}, },
} }
@ -169,7 +169,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
Model: gorm.Model{ID: 1}, Model: gorm.Model{ID: 1},
Group: Group{ Group: Group{
OptionsSerialized: GroupOption{ OptionsSerialized: GroupOption{
ShareDownloadEnabled: false, ShareDownload: false,
}, },
}, },
} }
@ -181,7 +181,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
user := &User{ user := &User{
Group: Group{ Group: Group{
OptionsSerialized: GroupOption{ OptionsSerialized: GroupOption{
ShareDownloadEnabled: true, ShareDownload: true,
}, },
}, },
} }
@ -195,7 +195,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
Model: gorm.Model{ID: 1}, Model: gorm.Model{ID: 1},
Group: Group{ Group: Group{
OptionsSerialized: GroupOption{ OptionsSerialized: GroupOption{
ShareDownloadEnabled: true, ShareDownload: true,
}, },
}, },
} }
@ -259,10 +259,10 @@ func TestShare_Purchase(t *testing.T) {
asserts.NoError(share.Purchase(&user)) asserts.NoError(share.Purchase(&user))
share.Score = 1 share.Score = 1
user.Group.OptionsSerialized.ShareFreeEnabled = true user.Group.OptionsSerialized.ShareFree = true
asserts.NoError(share.Purchase(&user)) asserts.NoError(share.Purchase(&user))
user.Group.OptionsSerialized.ShareFreeEnabled = false user.Group.OptionsSerialized.ShareFree = false
share.UserID = 1 share.UserID = 1
user.ID = 1 user.ID = 1
asserts.NoError(share.Purchase(&user)) asserts.NoError(share.Purchase(&user))

@ -0,0 +1,48 @@
package model
import (
"github.com/HFO4/cloudreve/pkg/util"
"github.com/jinzhu/gorm"
)
// Task 任务模型
type Task struct {
gorm.Model
Status int // 任务状态
Type int // 任务类型
UserID uint // 发起者UID0表示为系统发起
Progress int // 进度
Error string // 错误信息
Props string `gorm:"type:text"` // 任务属性
}
// Create 创建任务记录
func (task *Task) Create() (uint, error) {
if err := DB.Create(task).Error; err != nil {
util.Log().Warning("无法插入任务记录, %s", err)
return 0, err
}
return task.ID, nil
}
// SetStatus 设定任务状态
func (task *Task) SetStatus(status int) error {
return DB.Model(task).Select("status").Updates(map[string]interface{}{"status": status}).Error
}
// SetProgress 设定任务进度
func (task *Task) SetProgress(progress int) error {
return DB.Model(task).Select("progress").Updates(map[string]interface{}{"progress": progress}).Error
}
// SetError 设定错误信息
func (task *Task) SetError(err string) error {
return DB.Model(task).Select("error").Updates(map[string]interface{}{"error": err}).Error
}
// GetTasksByStatus 根据状态检索任务
func GetTasksByStatus(status int) []Task {
var tasks []Task
DB.Where("status = ?", status).Find(&tasks)
return tasks
}

@ -37,8 +37,8 @@ type User struct {
TwoFactor string `json:"-"` TwoFactor string `json:"-"`
Delay int Delay int
Avatar string Avatar string
Options string `json:"-",gorm:"size:4096"` Options string `json:"-",gorm:"type:text"`
Authn string `gorm:"size:8192"` Authn string `gorm:"type:text"`
Score int Score int
// 关联模型 // 关联模型

@ -20,7 +20,7 @@ import (
*/ */
// Compress 创建给定目录和文件的压缩文件 // Compress 创建给定目录和文件的压缩文件
func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (string, error) { func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint, isArchive bool) (string, error) {
// 查找待压缩目录 // 查找待压缩目录
folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID) folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
if err != nil && len(folders) != 0 { if err != nil && len(folders) != 0 {
@ -66,8 +66,13 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (
} }
// 创建临时压缩文件 // 创建临时压缩文件
saveFolder := "archive"
if !isArchive {
saveFolder = "compress"
}
zipFilePath := filepath.Join( zipFilePath := filepath.Join(
model.GetSettingByName("temp_path"), model.GetSettingByName("temp_path"),
saveFolder,
fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()), fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
) )
zipFile, err := util.CreatNestedFile(zipFilePath) zipFile, err := util.CreatNestedFile(zipFilePath)
@ -91,7 +96,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled return "", ErrClientCanceled
default: default:
fs.doCompress(ctx, nil, &folders[i], zipWriter, true) fs.doCompress(ctx, nil, &folders[i], zipWriter, isArchive)
} }
} }
@ -102,7 +107,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled return "", ErrClientCanceled
default: default:
fs.doCompress(ctx, &files[i], nil, zipWriter, true) fs.doCompress(ctx, &files[i], nil, zipWriter, isArchive)
} }
} }

@ -51,7 +51,7 @@ func TestFileSystem_Compress(t *testing.T) {
// 查找上传策略 // 查找上传策略
asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1)) asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1))
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.NoError(err) asserts.NoError(err)
asserts.NotEmpty(zipFile) asserts.NotEmpty(zipFile)
asserts.Contains(zipFile, "archive_") asserts.Contains(zipFile, "archive_")
@ -76,7 +76,7 @@ func TestFileSystem_Compress(t *testing.T) {
) )
asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.Error(err) asserts.Error(err)
asserts.Empty(zipFile) asserts.Empty(zipFile)
} }
@ -100,7 +100,7 @@ func TestFileSystem_Compress(t *testing.T) {
) )
asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.Error(err) asserts.Error(err)
asserts.Equal(ErrObjectNotExist, err) asserts.Equal(ErrObjectNotExist, err)
asserts.Empty(zipFile) asserts.Empty(zipFile)

@ -1,190 +0,0 @@
package filesystem
import (
"context"
"errors"
"fmt"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/cache"
"github.com/HFO4/cloudreve/pkg/conf"
"github.com/HFO4/cloudreve/pkg/filesystem/response"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"image"
"image/jpeg"
"os"
"testing"
)
func CreateTestImage() *os.File {
file, err := os.Create("TestFileSystem_GenerateThumbnail.jpeg")
alpha := image.NewAlpha(image.Rect(0, 0, 500, 200))
jpeg.Encode(file, alpha, nil)
if err != nil {
fmt.Println(err)
}
_, _ = file.Seek(0, 0)
return file
}
func TestFileSystem_GetThumb(t *testing.T) {
asserts := assert.New(t)
ctx := context.Background()
cache.Set("setting_preview_timeout", "60", 0)
// 正常
{
fs := FileSystem{
User: &model.User{
Model: gorm.Model{ID: 1},
},
}
testHandler := new(FileHeaderMock)
testHandler.On("Thumb", testMock.Anything, "123.jpg").Return(&response.ContentResponse{URL: "123"}, nil)
fs.Handler = testHandler
mock.ExpectQuery("SELECT(.+)").
WithArgs(10, 1).
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "pic_info", "source_name", "policy_id"}).
AddRow(10, "10,10", "123.jpg", 154),
)
mock.ExpectQuery("SELECT(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "type"}).
AddRow(154, "mock"),
)
res, err := fs.GetThumb(ctx, 10)
asserts.NoError(mock.ExpectationsWereMet())
testHandler.AssertExpectations(t)
asserts.NoError(err)
asserts.Equal("123", res.URL)
}
// 文件不存在
{
fs := FileSystem{
User: &model.User{
Model: gorm.Model{ID: 1},
},
}
mock.ExpectQuery("SELECT(.+)").
WithArgs(10, 1).
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "pic_info", "source_name"}),
)
_, err := fs.GetThumb(ctx, 10)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestFileSystem_GenerateThumbnail(t *testing.T) {
asserts := assert.New(t)
fs := FileSystem{
User: &model.User{
Model: gorm.Model{ID: 1},
},
}
ctx := context.Background()
// 成功
{
src := CreateTestImage()
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil)
fs.Handler = testHandler
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
file := &model.File{
Model: gorm.Model{ID: 1},
Name: "123.jpg",
SourceName: "TestFileSystem_GenerateThumbnail.jpeg",
}
fs.GenerateThumbnail(ctx, file)
asserts.NoError(mock.ExpectationsWereMet())
testHandler.AssertExpectations(t)
asserts.True(util.Exists("TestFileSystem_GenerateThumbnail.jpeg" + conf.ThumbConfig.FileSuffix))
}
// 成功,不进行数据库更新
{
src := CreateTestImage()
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil)
fs.Handler = testHandler
file := &model.File{
Name: "123.jpg",
SourceName: "TestFileSystem_GenerateThumbnail.jpeg",
}
fs.GenerateThumbnail(ctx, file)
asserts.NoError(mock.ExpectationsWereMet())
testHandler.AssertExpectations(t)
asserts.True(util.Exists("TestFileSystem_GenerateThumbnail.jpeg" + conf.ThumbConfig.FileSuffix))
}
// 更新信息失败后删除文件
{
src := CreateTestImage()
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil)
testHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, nil)
fs.Handler = testHandler
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
file := &model.File{
Model: gorm.Model{ID: 1},
Name: "123.jpg",
SourceName: "TestFileSystem_GenerateThumbnail.jpeg",
}
fs.GenerateThumbnail(ctx, file)
asserts.NoError(mock.ExpectationsWereMet())
testHandler.AssertExpectations(t)
}
// 不能生成缩略图
{
file := &model.File{
Model: gorm.Model{ID: 1},
Name: "123.123",
SourceName: "TestFileSystem_GenerateThumbnail.jpeg",
}
fs.GenerateThumbnail(ctx, file)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestFileSystem_GenerateThumbnailSize(t *testing.T) {
asserts := assert.New(t)
fs := FileSystem{
User: &model.User{
Model: gorm.Model{ID: 1},
},
}
asserts.NotPanics(func() {
_, _ = fs.GenerateThumbnailSize(0, 0)
})
}

@ -9,6 +9,7 @@ import (
"github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"os"
"path" "path"
) )
@ -114,8 +115,10 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe
var reqContext context.Context var reqContext context.Context
if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok { if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok {
reqContext = ginCtx.Request.Context() reqContext = ginCtx.Request.Context()
} else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok {
reqContext = reqCtx
} else { } else {
reqContext = ctx.Value(fsctx.HTTPCtx).(context.Context) return
} }
select { select {
@ -182,3 +185,42 @@ func (fs *FileSystem) GetUploadToken(ctx context.Context, path string, size uint
return &credential, nil return &credential, nil
} }
// UploadFromPath 将本机已有文件上传到用户的文件系统
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string) error {
file, err := os.Open(src)
if err != nil {
return err
}
defer file.Close()
// 获取源文件大小
fi, err := file.Stat()
if err != nil {
return err
}
size := fi.Size()
// 构建文件头
fileName := path.Base(dst)
filePath := path.Dir(dst)
fileData := local.FileStream{
File: file,
Size: uint64(size),
Name: fileName,
VirtualPath: filePath,
}
// 给文件系统分配钩子
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
fs.Use("AfterUploadCanceled", HookDeleteTempFile)
fs.Use("AfterUploadCanceled", HookGiveBackCapacity)
fs.Use("AfterUpload", GenericAfterUpload)
fs.Use("AfterValidateFailed", HookDeleteTempFile)
fs.Use("AfterValidateFailed", HookGiveBackCapacity)
fs.Use("AfterUploadFailed", HookGiveBackCapacity)
// 开始上传
return fs.Upload(ctx, fileData)
}

@ -77,9 +77,9 @@ func BuildUser(user model.User) User {
AllowShare: user.Group.ShareEnabled, AllowShare: user.Group.ShareEnabled,
AllowRemoteDownload: aria2Option[0], AllowRemoteDownload: aria2Option[0],
AllowTorrentDownload: aria2Option[2], AllowTorrentDownload: aria2Option[2],
AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownloadEnabled, AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownload,
ShareFreeEnabled: user.Group.OptionsSerialized.ShareFreeEnabled, ShareFreeEnabled: user.Group.OptionsSerialized.ShareFree,
ShareDownload: user.Group.OptionsSerialized.ShareDownloadEnabled, ShareDownload: user.Group.OptionsSerialized.ShareDownload,
}, },
} }
} }

@ -0,0 +1,155 @@
package task
import (
"context"
"encoding/json"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/util"
"os"
)
// CompressTask 文件压缩任务
type CompressTask struct {
User *model.User
TaskModel *model.Task
TaskProps CompressProps
Err *JobError
zipPath string
}
// CompressProps 压缩任务属性
type CompressProps struct {
Dirs []uint `json:"dirs"`
Files []uint `json:"files"`
Dst string `json:"dst"`
}
// Props 获取任务属性
func (job *CompressTask) Props() string {
res, _ := json.Marshal(job.TaskProps)
return string(res)
}
// Type 获取任务状态
func (job *CompressTask) Type() int {
return CompressTaskType
}
// Creator 获取创建者ID
func (job *CompressTask) Creator() uint {
return job.User.ID
}
// Model 获取任务的数据库模型
func (job *CompressTask) Model() *model.Task {
return job.TaskModel
}
// SetStatus 设定状态
func (job *CompressTask) SetStatus(status int) {
job.TaskModel.SetStatus(status)
}
// SetError 设定任务失败信息
func (job *CompressTask) SetError(err *JobError) {
job.Err = err
res, _ := json.Marshal(job.Err)
job.TaskModel.SetError(string(res))
// 删除压缩文件
job.removeZipFile()
}
func (job *CompressTask) removeZipFile() {
if job.zipPath != "" {
if err := os.Remove(job.zipPath); err != nil {
util.Log().Warning("无法删除临时压缩文件 %s , %s", job.zipPath, err)
}
}
}
// SetErrorMsg 设定任务失败信息
func (job *CompressTask) SetErrorMsg(msg string) {
job.SetError(&JobError{Msg: msg})
}
// GetError 返回任务失败信息
func (job *CompressTask) GetError() *JobError {
return job.Err
}
// Do 开始执行任务
func (job *CompressTask) Do() {
// 创建文件系统
fs, err := filesystem.NewFileSystem(job.User)
if err != nil {
job.SetErrorMsg(err.Error())
return
}
defer fs.Recycle()
util.Log().Debug("开始压缩文件")
job.TaskModel.SetProgress(CompressingProgress)
// 开始压缩
ctx := context.Background()
zipFile, err := fs.Compress(ctx, job.TaskProps.Dirs, job.TaskProps.Files, false)
if err != nil {
job.SetErrorMsg(err.Error())
return
}
job.zipPath = zipFile
util.Log().Debug("压缩文件存放至%s开始上传", zipFile)
job.TaskModel.SetProgress(TransferringProgress)
// 上传文件
err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst)
if err != nil {
job.SetErrorMsg(err.Error())
return
}
job.removeZipFile()
}
// NewCompressTask 新建压缩任务
func NewCompressTask(user *model.User, dst string, dirs, files []uint) (Job, error) {
newTask := &CompressTask{
User: user,
TaskProps: CompressProps{
Dirs: dirs,
Files: files,
Dst: dst,
},
}
record, err := Record(newTask)
if err != nil {
return nil, err
}
newTask.TaskModel = record
return newTask, nil
}
// NewCompressTaskFromModel 从数据库记录中恢复压缩任务
func NewCompressTaskFromModel(task *model.Task) (Job, error) {
user, err := model.GetUserByID(task.UserID)
if err != nil {
return nil, err
}
newTask := &CompressTask{
User: &user,
TaskModel: task,
}
err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps)
if err != nil {
return nil, err
}
return newTask, nil
}

@ -0,0 +1,8 @@
package task
import "errors"
var (
// ErrUnknownTaskType 未知任务类型
ErrUnknownTaskType = errors.New("未知任务类型")
)

@ -0,0 +1,98 @@
package task
import (
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/util"
)
// 任务类型
const (
// CompressTaskType 压缩任务
CompressTaskType = iota
)
// 任务状态
const (
// Queued 排队中
Queued = iota
// Processing 处理中
Processing
// Error 失败
Error
// Canceled 取消
Canceled
// Complete 完成
Complete
)
// 任务进度
const (
// Compressing 压缩中
CompressingProgress = iota
// Decompressing 解压缩中
DecompressingProgress
// Downloading 下载中
DownloadingProgress
// Transferring 转存中
TransferringProgress
)
// Job 任务接口
type Job interface {
Type() int // 返回任务类型
Creator() uint // 返回创建者ID
Props() string // 返回序列化后的任务属性
Model() *model.Task // 返回对应的数据库模型
SetStatus(int) // 设定任务状态
Do() // 开始执行任务
SetError(*JobError) // 设定任务失败信息
GetError() *JobError // 获取任务执行结果返回nil表示成功完成执行
}
// JobError 任务失败信息
type JobError struct {
Msg string
}
// Record 将任务记录到数据库中
func Record(job Job) (*model.Task, error) {
record := model.Task{
Status: Queued,
Type: job.Type(),
UserID: job.Creator(),
Progress: 0,
Error: "",
Props: job.Props(),
}
_, err := record.Create()
return &record, err
}
// Resume 从数据库中恢复未完成任务
func Resume() {
tasks := model.GetTasksByStatus(Queued)
if len(tasks) == 0 {
return
}
util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks))
for i := 0; i < len(tasks); i++ {
job, err := GetJobFromModel(&tasks[i])
if err != nil {
util.Log().Warning("无法恢复任务,%s", err)
continue
}
TaskPoll.Submit(job)
}
}
// GetJobFromModel 从数据库给定模型获取任务
func GetJobFromModel(task *model.Task) (Job, error) {
switch task.Type {
case CompressTaskType:
return NewCompressTaskFromModel(task)
default:
return nil, ErrUnknownTaskType
}
}

@ -1,124 +1,60 @@
package task package task
import "sync" import (
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/util"
)
// Pool 带有最大配额的goroutines任务池 // TaskPoll 要使用的任务池
var TaskPoll *Pool
// Pool 带有最大配额的任务池
type Pool struct { type Pool struct {
// 容量 // 容量
capacity int idleWorker chan int
// 初始容量
initialCapacity int
// 终止信号
terminateSignal chan error
// 全部任务完成的信号
finishSignal chan bool
// 有空余位置的信号
freeSignal chan bool
// 是否已关闭
closed bool
// 是否正在等待任务结束
waiting bool
// 互斥锁
lock sync.Mutex
// 等待队列
pending []Job
}
// Job 任务
type Job interface {
// 任务处理方法如果error不为nil
// 任务池会关闭并中止接受新任务
Do() error
} }
// NewGoroutinePool 创建一个容量为capacity的任务池 // Add 增加可用Worker数量
func NewGoroutinePool(capacity int) *Pool { func (pool *Pool) Add(num int) {
pool := &Pool{ for i := 0; i < num; i++ {
capacity: capacity, pool.idleWorker <- 1
initialCapacity: capacity,
terminateSignal: make(chan error),
finishSignal: make(chan bool),
freeSignal: make(chan bool),
} }
go pool.Schedule()
return pool
} }
// Schedule 为等待队列的任务分配Worker以及检测错误状态、所有任务完成 // ObtainWorker 阻塞直到获取新的Worker
func (pool *Pool) Schedule() { func (pool *Pool) ObtainWorker() Worker {
for { select {
select { case <-pool.idleWorker:
case <-pool.freeSignal: // 有空闲Worker名额时返回新Worker
// 有新的空余名额 return &GeneralWorker{}
pool.lock.Lock()
if len(pool.pending) > 0 {
// 有待处理的任务,开始处理
var job Job
job, pool.pending = pool.pending[0], pool.pending[1:]
go pool.start(job)
} else {
if pool.waiting && pool.capacity == pool.initialCapacity {
// 所有任务已结束
pool.lock.Unlock()
pool.finishSignal <- true
return
}
pool.lock.Unlock()
}
case <-pool.terminateSignal:
// 有任务意外中止,则发送完成信号
pool.finishSignal <- true
return
}
} }
} }
// Wait 等待队列中所有任务完成或有Job返回错误中止 // FreeWorker 添加空闲Worker
func (pool *Pool) Wait() chan bool { func (pool *Pool) FreeWorker() {
pool.lock.Lock() pool.Add(1)
pool.waiting = true
pool.lock.Unlock()
return pool.finishSignal
} }
// Submit 提交任务 // Submit 开始提交任务
func (pool *Pool) Submit(job Job) { func (pool *Pool) Submit(job Job) {
if pool.closed { go func() {
return util.Log().Debug("等待获取Worker")
} worker := pool.ObtainWorker()
util.Log().Debug("获取到Worker")
pool.lock.Lock() worker.Do(job)
if pool.capacity < 1 { util.Log().Debug("释放Worker")
// 容量为空时,加入等待队列 pool.FreeWorker()
pool.pending = append(pool.pending, job) }()
pool.lock.Unlock()
return
}
// 还有空闲容量时,开始执行任务
go pool.start(job)
} }
// 开始执行任务 // Init 初始化任务池
func (pool *Pool) start(job Job) { func Init() {
pool.capacity-- maxWorker := model.GetIntSetting("max_worker_num", 10)
pool.lock.Unlock() TaskPoll = &Pool{
idleWorker: make(chan int, maxWorker),
err := job.Do()
if err != nil {
pool.closed = true
select {
case <-pool.terminateSignal:
default:
close(pool.terminateSignal)
}
} }
TaskPoll.Add(maxWorker)
util.Log().Info("初始化任务队列WorkerNum = %d", maxWorker)
pool.lock.Lock() Resume()
pool.capacity++
pool.lock.Unlock()
pool.freeSignal <- true
} }

@ -0,0 +1,41 @@
package task
import "github.com/HFO4/cloudreve/pkg/util"
// Worker 处理任务的对象
type Worker interface {
Do(Job) // 执行任务
}
// GeneralWorker 通用Worker
type GeneralWorker struct {
}
// Do 执行任务
func (worker *GeneralWorker) Do(job Job) {
util.Log().Debug("开始执行任务")
job.SetStatus(Processing)
defer func() {
// 致命错误捕获
if err := recover(); err != nil {
util.Log().Debug("任务执行出错panic")
job.SetError(&JobError{Msg: "致命错误"})
job.SetStatus(Error)
}
}()
// 开始执行任务
job.Do()
// 任务执行失败
if err := job.GetError(); err != nil {
util.Log().Debug("任务执行出错")
job.SetStatus(Error)
return
}
util.Log().Debug("任务执行完成")
// 执行完成
job.SetStatus(Complete)
}

@ -47,6 +47,17 @@ func Archive(c *gin.Context) {
} }
} }
// Compress 创建文件压缩任务
func Compress(c *gin.Context) {
var service explorer.ItemCompressService
if err := c.ShouldBindJSON(&service); err == nil {
res := service.CreateCompressTask(c)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// AnonymousGetContent 匿名获取文件资源 // AnonymousGetContent 匿名获取文件资源
func AnonymousGetContent(c *gin.Context) { func AnonymousGetContent(c *gin.Context) {
// 创建上下文 // 创建上下文

@ -267,6 +267,8 @@ func InitMasterRouter() *gin.Engine {
file.GET("source/:id", controllers.GetSource) file.GET("source/:id", controllers.GetSource)
// 打包要下载的文件 // 打包要下载的文件
file.POST("archive", controllers.Archive) file.POST("archive", controllers.Archive)
// 创建文件压缩任务
file.POST("compress", controllers.Compress)
} }
// 目录 // 目录

@ -75,7 +75,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled { if fs.User.Group.OptionsSerialized.OneTimeDownload {
// 清理资源,删除临时文件 // 清理资源,删除临时文件
_ = cache.Deletes([]string{service.ID}, "archive_") _ = cache.Deletes([]string{service.ID}, "archive_")
} }
@ -205,7 +205,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
// 设置文件名 // 设置文件名
c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"")
if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled { if fs.User.Group.OptionsSerialized.OneTimeDownload {
// 清理资源,删除临时文件 // 清理资源,删除临时文件
_ = cache.Deletes([]string{service.ID}, "download_") _ = cache.Deletes([]string{service.ID}, "download_")
} }

@ -9,9 +9,12 @@ import (
"github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/task"
"github.com/HFO4/cloudreve/pkg/util" "github.com/HFO4/cloudreve/pkg/util"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"math"
"net/url" "net/url"
"path"
"time" "time"
) )
@ -34,6 +37,80 @@ type ItemService struct {
Dirs []uint `json:"dirs" binding:"exists"` Dirs []uint `json:"dirs" binding:"exists"`
} }
// ItemCompressService 文件压缩任务服务
type ItemCompressService struct {
Src ItemService `json:"src" binding:"exists"`
Dst string `json:"dst" binding:"required,min=1,max=65535"`
Name string `json:"name" binding:"required,min=1,max=255"`
}
// CreateCompressTask 创建文件压缩任务
func (service *ItemCompressService) CreateCompressTask(c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
defer fs.Recycle()
// 检查用户组权限
if !fs.User.Group.OptionsSerialized.ArchiveTask {
return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil)
}
// 存放目录是否存在,是否重名
if exist, _ := fs.IsPathExist(service.Dst); !exist {
return serializer.Err(serializer.CodeNotFound, "存放路径不存在", nil)
}
if exist, _ := fs.IsFileExist(path.Join(service.Dst, service.Name)); exist {
return serializer.ParamErr("名为 "+service.Name+" 的文件已存在", nil)
}
// 检查文件名合法性
if !fs.ValidateLegalName(context.Background(), service.Name) {
return serializer.ParamErr("文件名非法", nil)
}
if !fs.ValidateExtension(context.Background(), service.Name) {
return serializer.ParamErr("不允许存储此扩展名的文件", nil)
}
// 递归列出待压缩子目录
folders, err := model.GetRecursiveChildFolder(service.Src.Dirs, fs.User.ID, true)
if err != nil {
return serializer.Err(serializer.CodeDBError, "无法列出子目录", err)
}
// 列出所有待压缩文件
files, err := model.GetChildFilesOfFolders(&folders)
if err != nil {
return serializer.Err(serializer.CodeDBError, "无法列出子文件", err)
}
// 计算待压缩文件大小
var totalSize uint64
for i := 0; i < len(files); i++ {
totalSize += files[i].Size
}
// 按照平均压缩率计算用户空间是否足够
compressRatio := 0.4
spaceNeeded := uint64(math.Round(float64(totalSize) * compressRatio))
if fs.User.GetRemainingCapacity() < spaceNeeded {
return serializer.Err(serializer.CodeParamErr, "剩余空间不足", err)
}
// 创建任务
job, err := task.NewCompressTask(fs.User, path.Join(service.Dst, service.Name), service.Src.Dirs,
service.Src.Items)
if err != nil {
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
}
task.TaskPoll.Submit(job)
return serializer.Response{}
}
// Archive 创建归档 // Archive 创建归档
func (service *ItemService) Archive(ctx context.Context, c *gin.Context) serializer.Response { func (service *ItemService) Archive(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统 // 创建文件系统
@ -44,13 +121,13 @@ func (service *ItemService) Archive(ctx context.Context, c *gin.Context) seriali
defer fs.Recycle() defer fs.Recycle()
// 检查用户组权限 // 检查用户组权限
if !fs.User.Group.OptionsSerialized.ArchiveDownloadEnabled { if !fs.User.Group.OptionsSerialized.ArchiveDownload {
return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil) return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil)
} }
// 开始压缩 // 开始压缩
ctx = context.WithValue(ctx, fsctx.GinCtx, c) ctx = context.WithValue(ctx, fsctx.GinCtx, c)
zipFile, err := fs.Compress(ctx, service.Dirs, service.Items) zipFile, err := fs.Compress(ctx, service.Dirs, service.Items, true)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, "无法创建压缩文件", err) return serializer.Err(serializer.CodeNotSet, "无法创建压缩文件", err)
} }

@ -281,7 +281,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response {
user := userCtx.(*model.User) user := userCtx.(*model.User)
// 是否有权限 // 是否有权限
if !user.Group.OptionsSerialized.ArchiveDownloadEnabled { if !user.Group.OptionsSerialized.ArchiveDownload {
return serializer.Err(serializer.CodeNoPermissionErr, "您的用户组无权进行此操作", nil) return serializer.Err(serializer.CodeNoPermissionErr, "您的用户组无权进行此操作", nil)
} }
@ -310,7 +310,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response {
// 用于调下层service // 用于调下层service
tempUser := share.Creator() tempUser := share.Creator()
tempUser.Group.OptionsSerialized.ArchiveDownloadEnabled = true tempUser.Group.OptionsSerialized.ArchiveDownload = true
c.Set("user", tempUser) c.Set("user", tempUser)
subService := explorer.ItemService{ subService := explorer.ItemService{

Loading…
Cancel
Save