diff --git a/main.go b/main.go index 9cdc9c0..710b430 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github.com/HFO4/cloudreve/pkg/authn" "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/conf" + "github.com/HFO4/cloudreve/pkg/task" "github.com/HFO4/cloudreve/routers" "github.com/gin-gonic/gin" ) @@ -20,6 +21,7 @@ func init() { if conf.SystemConfig.Mode == "master" { model.Init() authn.Init() + task.Init() } auth.Init() } diff --git a/middleware/share_test.go b/middleware/share_test.go index c61f86d..16ae822 100644 --- a/middleware/share_test.go +++ b/middleware/share_test.go @@ -136,7 +136,7 @@ func TestBeforeShareDownload(t *testing.T) { c.Set("user", &model.User{ Model: gorm.Model{ID: 1}, Group: model.Group{OptionsSerialized: model.GroupOption{ - ShareDownloadEnabled: true, + ShareDownload: true, }}, }) testFunc(c) diff --git a/models/group.go b/models/group.go index 9f4af46..bf2ffce 100644 --- a/models/group.go +++ b/models/group.go @@ -16,7 +16,7 @@ type Group struct { Aria2Option string Color string SpeedLimit int - Options string `json:"-",gorm:"size:4096"` + Options string `json:"-",gorm:"type:text"` // 数据库忽略字段 PolicyList []uint `gorm:"-"` @@ -25,11 +25,11 @@ type Group struct { // GroupOption 用户组其他配置 type GroupOption struct { - ArchiveDownloadEnabled bool `json:"archive_download,omitempty"` - ArchiveTaskEnabled bool `json:"archive_task,omitempty"` - OneTimeDownloadEnabled bool `json:"one_time_download,omitempty"` - ShareDownloadEnabled bool `json:"share_download,omitempty"` - ShareFreeEnabled bool `json:"share_free,omitempty"` + ArchiveDownload bool `json:"archive_download,omitempty"` + ArchiveTask bool `json:"archive_task,omitempty"` + OneTimeDownload bool `json:"one_time_download,omitempty"` + ShareDownload bool `json:"share_download,omitempty"` + ShareFree bool `json:"share_free,omitempty"` } // GetAria2Option 获取用户离线下载设备 diff --git a/models/init.go b/models/init.go index c782745..bafc977 100644 --- a/models/init.go +++ b/models/init.go @@ -47,7 +47,7 @@ func Init() { // Debug模式下,输出所有 SQL 日志 if conf.SystemConfig.Debug { - db.LogMode(true) + db.LogMode(false) } //db.SetLogger(util.Log()) diff --git a/models/migration.go b/models/migration.go index a09d2a6..d025c2f 100644 --- a/models/migration.go +++ b/models/migration.go @@ -29,7 +29,7 @@ func migration() { if conf.DatabaseConfig.Type == "mysql" { DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") } - DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}) + DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}, &Task{}) // 创建初始存储策略 addDefaultPolicy() @@ -159,7 +159,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "aria2_token", Value: `your token`, Type: "aria2"}, {Name: "aria2_rpcurl", Value: `http://127.0.0.1:6800/`, Type: "aria2"}, {Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"}, - {Name: "task_queue_token", Value: ``, Type: "task"}, + {Name: "max_worker_num", Value: `10`, Type: "task"}, {Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"}, {Name: "temp_path", Value: "temp", Type: "path"}, {Name: "score_enabled", Value: "1", Type: "score"}, @@ -186,9 +186,9 @@ func addDefaultGroups() { WebDAVEnabled: true, Aria2Option: "0,0,0", OptionsSerialized: GroupOption{ - ArchiveDownloadEnabled: true, - ArchiveTaskEnabled: true, - ShareDownloadEnabled: true, + ArchiveDownload: true, + ArchiveTask: true, + ShareDownload: true, }, } if err := DB.Create(&defaultAdminGroup).Error; err != nil { diff --git a/models/policy.go b/models/policy.go index c78e157..0c36435 100644 --- a/models/policy.go +++ b/models/policy.go @@ -22,14 +22,14 @@ type Policy struct { BucketName string IsPrivate bool BaseURL string - AccessKey string `gorm:"size:1024"` - SecretKey string `gorm:"size:1024"` + AccessKey string `gorm:"type:text"` + SecretKey string `gorm:"type:text"` MaxSize uint64 AutoRename bool DirNameRule string FileNameRule string IsOriginLinkEnable bool - Options string `gorm:"size:4096"` + Options string `gorm:"type:text"` // 数据库忽略字段 OptionsSerialized PolicyOption `gorm:"-"` diff --git a/models/share.go b/models/share.go index 2df25e3..9dcde9d 100644 --- a/models/share.go +++ b/models/share.go @@ -123,7 +123,7 @@ func (share *Share) SourceFile() *File { // CanBeDownloadBy 返回此分享是否可以被给定用户下载 func (share *Share) CanBeDownloadBy(user *User) error { // 用户组权限 - if !user.Group.OptionsSerialized.ShareDownloadEnabled { + if !user.Group.OptionsSerialized.ShareDownload { if user.IsAnonymous() { return errors.New("未登录用户无法下载") } @@ -169,7 +169,7 @@ func (share *Share) DownloadBy(user *User, c *gin.Context) error { // Purchase 使用积分购买分享 func (share *Share) Purchase(user *User) error { // 不需要付积分 - if share.Score == 0 || user.Group.OptionsSerialized.ShareFreeEnabled || user.ID == share.UserID { + if share.Score == 0 || user.Group.OptionsSerialized.ShareFree || user.ID == share.UserID { return nil } diff --git a/models/share_test.go b/models/share_test.go index de4faf9..e390a29 100644 --- a/models/share_test.go +++ b/models/share_test.go @@ -156,7 +156,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) { user := &User{ Group: Group{ OptionsSerialized: GroupOption{ - ShareDownloadEnabled: false, + ShareDownload: false, }, }, } @@ -169,7 +169,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) { Model: gorm.Model{ID: 1}, Group: Group{ OptionsSerialized: GroupOption{ - ShareDownloadEnabled: false, + ShareDownload: false, }, }, } @@ -181,7 +181,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) { user := &User{ Group: Group{ OptionsSerialized: GroupOption{ - ShareDownloadEnabled: true, + ShareDownload: true, }, }, } @@ -195,7 +195,7 @@ func TestShare_CanBeDownloadBy(t *testing.T) { Model: gorm.Model{ID: 1}, Group: Group{ OptionsSerialized: GroupOption{ - ShareDownloadEnabled: true, + ShareDownload: true, }, }, } @@ -259,10 +259,10 @@ func TestShare_Purchase(t *testing.T) { asserts.NoError(share.Purchase(&user)) share.Score = 1 - user.Group.OptionsSerialized.ShareFreeEnabled = true + user.Group.OptionsSerialized.ShareFree = true asserts.NoError(share.Purchase(&user)) - user.Group.OptionsSerialized.ShareFreeEnabled = false + user.Group.OptionsSerialized.ShareFree = false share.UserID = 1 user.ID = 1 asserts.NoError(share.Purchase(&user)) diff --git a/models/task.go b/models/task.go new file mode 100644 index 0000000..400fc91 --- /dev/null +++ b/models/task.go @@ -0,0 +1,48 @@ +package model + +import ( + "github.com/HFO4/cloudreve/pkg/util" + "github.com/jinzhu/gorm" +) + +// Task 任务模型 +type Task struct { + gorm.Model + Status int // 任务状态 + Type int // 任务类型 + UserID uint // 发起者UID,0表示为系统发起 + Progress int // 进度 + Error string // 错误信息 + Props string `gorm:"type:text"` // 任务属性 +} + +// Create 创建任务记录 +func (task *Task) Create() (uint, error) { + if err := DB.Create(task).Error; err != nil { + util.Log().Warning("无法插入任务记录, %s", err) + return 0, err + } + return task.ID, nil +} + +// SetStatus 设定任务状态 +func (task *Task) SetStatus(status int) error { + return DB.Model(task).Select("status").Updates(map[string]interface{}{"status": status}).Error +} + +// SetProgress 设定任务进度 +func (task *Task) SetProgress(progress int) error { + return DB.Model(task).Select("progress").Updates(map[string]interface{}{"progress": progress}).Error +} + +// SetError 设定错误信息 +func (task *Task) SetError(err string) error { + return DB.Model(task).Select("error").Updates(map[string]interface{}{"error": err}).Error +} + +// GetTasksByStatus 根据状态检索任务 +func GetTasksByStatus(status int) []Task { + var tasks []Task + DB.Where("status = ?", status).Find(&tasks) + return tasks +} diff --git a/models/user.go b/models/user.go index 710a043..aef7d08 100644 --- a/models/user.go +++ b/models/user.go @@ -37,8 +37,8 @@ type User struct { TwoFactor string `json:"-"` Delay int Avatar string - Options string `json:"-",gorm:"size:4096"` - Authn string `gorm:"size:8192"` + Options string `json:"-",gorm:"type:text"` + Authn string `gorm:"type:text"` Score int // 关联模型 diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 843d425..18a9965 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -20,7 +20,7 @@ import ( */ // Compress 创建给定目录和文件的压缩文件 -func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) (string, error) { +func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint, isArchive bool) (string, error) { // 查找待压缩目录 folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID) if err != nil && len(folders) != 0 { @@ -66,8 +66,13 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( } // 创建临时压缩文件 + saveFolder := "archive" + if !isArchive { + saveFolder = "compress" + } zipFilePath := filepath.Join( model.GetSettingByName("temp_path"), + saveFolder, fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()), ) zipFile, err := util.CreatNestedFile(zipFilePath) @@ -91,7 +96,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) return "", ErrClientCanceled default: - fs.doCompress(ctx, nil, &folders[i], zipWriter, true) + fs.doCompress(ctx, nil, &folders[i], zipWriter, isArchive) } } @@ -102,7 +107,7 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) return "", ErrClientCanceled default: - fs.doCompress(ctx, &files[i], nil, zipWriter, true) + fs.doCompress(ctx, &files[i], nil, zipWriter, isArchive) } } diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go index e814213..2f28cc3 100644 --- a/pkg/filesystem/archive_test.go +++ b/pkg/filesystem/archive_test.go @@ -51,7 +51,7 @@ func TestFileSystem_Compress(t *testing.T) { // 查找上传策略 asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1)) - zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) + zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true) asserts.NoError(err) asserts.NotEmpty(zipFile) asserts.Contains(zipFile, "archive_") @@ -76,7 +76,7 @@ func TestFileSystem_Compress(t *testing.T) { ) asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) + zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true) asserts.Error(err) asserts.Empty(zipFile) } @@ -100,7 +100,7 @@ func TestFileSystem_Compress(t *testing.T) { ) asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}) + zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true) asserts.Error(err) asserts.Equal(ErrObjectNotExist, err) asserts.Empty(zipFile) diff --git a/pkg/filesystem/image_test.go b/pkg/filesystem/image_test.go deleted file mode 100644 index 23bdf15..0000000 --- a/pkg/filesystem/image_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "fmt" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/HFO4/cloudreve/models" - "github.com/HFO4/cloudreve/pkg/cache" - "github.com/HFO4/cloudreve/pkg/conf" - "github.com/HFO4/cloudreve/pkg/filesystem/response" - "github.com/HFO4/cloudreve/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "image" - "image/jpeg" - "os" - "testing" -) - -func CreateTestImage() *os.File { - file, err := os.Create("TestFileSystem_GenerateThumbnail.jpeg") - alpha := image.NewAlpha(image.Rect(0, 0, 500, 200)) - jpeg.Encode(file, alpha, nil) - if err != nil { - fmt.Println(err) - } - _, _ = file.Seek(0, 0) - return file -} - -func TestFileSystem_GetThumb(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - cache.Set("setting_preview_timeout", "60", 0) - - // 正常 - { - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - } - testHandler := new(FileHeaderMock) - testHandler.On("Thumb", testMock.Anything, "123.jpg").Return(&response.ContentResponse{URL: "123"}, nil) - fs.Handler = testHandler - mock.ExpectQuery("SELECT(.+)"). - WithArgs(10, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "pic_info", "source_name", "policy_id"}). - AddRow(10, "10,10", "123.jpg", 154), - ) - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "type"}). - AddRow(154, "mock"), - ) - - res, err := fs.GetThumb(ctx, 10) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.NoError(err) - asserts.Equal("123", res.URL) - } - - // 文件不存在 - { - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectQuery("SELECT(.+)"). - WithArgs(10, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "pic_info", "source_name"}), - ) - - _, err := fs.GetThumb(ctx, 10) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestFileSystem_GenerateThumbnail(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - } - ctx := context.Background() - - // 成功 - { - src := CreateTestImage() - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil) - fs.Handler = testHandler - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - file := &model.File{ - Model: gorm.Model{ID: 1}, - Name: "123.jpg", - SourceName: "TestFileSystem_GenerateThumbnail.jpeg", - } - - fs.GenerateThumbnail(ctx, file) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.True(util.Exists("TestFileSystem_GenerateThumbnail.jpeg" + conf.ThumbConfig.FileSuffix)) - - } - - // 成功,不进行数据库更新 - { - src := CreateTestImage() - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil) - fs.Handler = testHandler - - file := &model.File{ - Name: "123.jpg", - SourceName: "TestFileSystem_GenerateThumbnail.jpeg", - } - - fs.GenerateThumbnail(ctx, file) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.True(util.Exists("TestFileSystem_GenerateThumbnail.jpeg" + conf.ThumbConfig.FileSuffix)) - - } - - // 更新信息失败后删除文件 - { - src := CreateTestImage() - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "TestFileSystem_GenerateThumbnail.jpeg").Return(src, nil) - testHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, nil) - fs.Handler = testHandler - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - - file := &model.File{ - Model: gorm.Model{ID: 1}, - Name: "123.jpg", - SourceName: "TestFileSystem_GenerateThumbnail.jpeg", - } - - fs.GenerateThumbnail(ctx, file) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - - } - - // 不能生成缩略图 - { - file := &model.File{ - Model: gorm.Model{ID: 1}, - Name: "123.123", - SourceName: "TestFileSystem_GenerateThumbnail.jpeg", - } - - fs.GenerateThumbnail(ctx, file) - asserts.NoError(mock.ExpectationsWereMet()) - } - -} - -func TestFileSystem_GenerateThumbnailSize(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - } - asserts.NotPanics(func() { - _, _ = fs.GenerateThumbnailSize(0, 0) - }) - -} diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index c4c880e..96eb702 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -9,6 +9,7 @@ import ( "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" + "os" "path" ) @@ -114,8 +115,10 @@ func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file FileHe var reqContext context.Context if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok { reqContext = ginCtx.Request.Context() + } else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok { + reqContext = reqCtx } else { - reqContext = ctx.Value(fsctx.HTTPCtx).(context.Context) + return } select { @@ -182,3 +185,42 @@ func (fs *FileSystem) GetUploadToken(ctx context.Context, path string, size uint return &credential, nil } + +// UploadFromPath 将本机已有文件上传到用户的文件系统 +func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string) error { + file, err := os.Open(src) + if err != nil { + return err + } + defer file.Close() + + // 获取源文件大小 + fi, err := file.Stat() + if err != nil { + return err + } + size := fi.Size() + + // 构建文件头 + fileName := path.Base(dst) + filePath := path.Dir(dst) + fileData := local.FileStream{ + File: file, + Size: uint64(size), + Name: fileName, + VirtualPath: filePath, + } + + // 给文件系统分配钩子 + fs.Use("BeforeUpload", HookValidateFile) + fs.Use("BeforeUpload", HookValidateCapacity) + fs.Use("AfterUploadCanceled", HookDeleteTempFile) + fs.Use("AfterUploadCanceled", HookGiveBackCapacity) + fs.Use("AfterUpload", GenericAfterUpload) + fs.Use("AfterValidateFailed", HookDeleteTempFile) + fs.Use("AfterValidateFailed", HookGiveBackCapacity) + fs.Use("AfterUploadFailed", HookGiveBackCapacity) + + // 开始上传 + return fs.Upload(ctx, fileData) +} diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go index 396188b..0f11e04 100644 --- a/pkg/serializer/user.go +++ b/pkg/serializer/user.go @@ -77,9 +77,9 @@ func BuildUser(user model.User) User { AllowShare: user.Group.ShareEnabled, AllowRemoteDownload: aria2Option[0], AllowTorrentDownload: aria2Option[2], - AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownloadEnabled, - ShareFreeEnabled: user.Group.OptionsSerialized.ShareFreeEnabled, - ShareDownload: user.Group.OptionsSerialized.ShareDownloadEnabled, + AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownload, + ShareFreeEnabled: user.Group.OptionsSerialized.ShareFree, + ShareDownload: user.Group.OptionsSerialized.ShareDownload, }, } } diff --git a/pkg/task/compress.go b/pkg/task/compress.go new file mode 100644 index 0000000..52e4670 --- /dev/null +++ b/pkg/task/compress.go @@ -0,0 +1,155 @@ +package task + +import ( + "context" + "encoding/json" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem" + "github.com/HFO4/cloudreve/pkg/util" + "os" +) + +// CompressTask 文件压缩任务 +type CompressTask struct { + User *model.User + TaskModel *model.Task + TaskProps CompressProps + Err *JobError + + zipPath string +} + +// CompressProps 压缩任务属性 +type CompressProps struct { + Dirs []uint `json:"dirs"` + Files []uint `json:"files"` + Dst string `json:"dst"` +} + +// Props 获取任务属性 +func (job *CompressTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *CompressTask) Type() int { + return CompressTaskType +} + +// Creator 获取创建者ID +func (job *CompressTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *CompressTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *CompressTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *CompressTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) + + // 删除压缩文件 + job.removeZipFile() +} + +func (job *CompressTask) removeZipFile() { + if job.zipPath != "" { + if err := os.Remove(job.zipPath); err != nil { + util.Log().Warning("无法删除临时压缩文件 %s , %s", job.zipPath, err) + } + } +} + +// SetErrorMsg 设定任务失败信息 +func (job *CompressTask) SetErrorMsg(msg string) { + job.SetError(&JobError{Msg: msg}) +} + +// GetError 返回任务失败信息 +func (job *CompressTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *CompressTask) Do() { + // 创建文件系统 + fs, err := filesystem.NewFileSystem(job.User) + if err != nil { + job.SetErrorMsg(err.Error()) + return + } + defer fs.Recycle() + + util.Log().Debug("开始压缩文件") + job.TaskModel.SetProgress(CompressingProgress) + + // 开始压缩 + ctx := context.Background() + zipFile, err := fs.Compress(ctx, job.TaskProps.Dirs, job.TaskProps.Files, false) + if err != nil { + job.SetErrorMsg(err.Error()) + return + } + job.zipPath = zipFile + + util.Log().Debug("压缩文件存放至%s,开始上传", zipFile) + job.TaskModel.SetProgress(TransferringProgress) + + // 上传文件 + err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst) + if err != nil { + job.SetErrorMsg(err.Error()) + return + } + + job.removeZipFile() +} + +// NewCompressTask 新建压缩任务 +func NewCompressTask(user *model.User, dst string, dirs, files []uint) (Job, error) { + newTask := &CompressTask{ + User: user, + TaskProps: CompressProps{ + Dirs: dirs, + Files: files, + Dst: dst, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewCompressTaskFromModel 从数据库记录中恢复压缩任务 +func NewCompressTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &CompressTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/task/errors.go b/pkg/task/errors.go new file mode 100644 index 0000000..ad9df0c --- /dev/null +++ b/pkg/task/errors.go @@ -0,0 +1,8 @@ +package task + +import "errors" + +var ( + // ErrUnknownTaskType 未知任务类型 + ErrUnknownTaskType = errors.New("未知任务类型") +) diff --git a/pkg/task/job.go b/pkg/task/job.go new file mode 100644 index 0000000..057483d --- /dev/null +++ b/pkg/task/job.go @@ -0,0 +1,98 @@ +package task + +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/util" +) + +// 任务类型 +const ( + // CompressTaskType 压缩任务 + CompressTaskType = iota +) + +// 任务状态 +const ( + // Queued 排队中 + Queued = iota + // Processing 处理中 + Processing + // Error 失败 + Error + // Canceled 取消 + Canceled + // Complete 完成 + Complete +) + +// 任务进度 +const ( + // Compressing 压缩中 + CompressingProgress = iota + // Decompressing 解压缩中 + DecompressingProgress + // Downloading 下载中 + DownloadingProgress + // Transferring 转存中 + TransferringProgress +) + +// Job 任务接口 +type Job interface { + Type() int // 返回任务类型 + Creator() uint // 返回创建者ID + Props() string // 返回序列化后的任务属性 + Model() *model.Task // 返回对应的数据库模型 + SetStatus(int) // 设定任务状态 + Do() // 开始执行任务 + SetError(*JobError) // 设定任务失败信息 + GetError() *JobError // 获取任务执行结果,返回nil表示成功完成执行 +} + +// JobError 任务失败信息 +type JobError struct { + Msg string +} + +// Record 将任务记录到数据库中 +func Record(job Job) (*model.Task, error) { + record := model.Task{ + Status: Queued, + Type: job.Type(), + UserID: job.Creator(), + Progress: 0, + Error: "", + Props: job.Props(), + } + _, err := record.Create() + return &record, err +} + +// Resume 从数据库中恢复未完成任务 +func Resume() { + tasks := model.GetTasksByStatus(Queued) + if len(tasks) == 0 { + return + } + util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks)) + + for i := 0; i < len(tasks); i++ { + job, err := GetJobFromModel(&tasks[i]) + if err != nil { + util.Log().Warning("无法恢复任务,%s", err) + continue + } + + TaskPoll.Submit(job) + } +} + +// GetJobFromModel 从数据库给定模型获取任务 +func GetJobFromModel(task *model.Task) (Job, error) { + switch task.Type { + case CompressTaskType: + return NewCompressTaskFromModel(task) + default: + return nil, ErrUnknownTaskType + } +} diff --git a/pkg/task/pool.go b/pkg/task/pool.go index 6b62f01..4982f90 100644 --- a/pkg/task/pool.go +++ b/pkg/task/pool.go @@ -1,124 +1,60 @@ package task -import "sync" +import ( + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/util" +) -// Pool 带有最大配额的goroutines任务池 +// TaskPoll 要使用的任务池 +var TaskPoll *Pool + +// Pool 带有最大配额的任务池 type Pool struct { // 容量 - capacity int - // 初始容量 - initialCapacity int - - // 终止信号 - terminateSignal chan error - // 全部任务完成的信号 - finishSignal chan bool - // 有空余位置的信号 - freeSignal chan bool - - // 是否已关闭 - closed bool - // 是否正在等待任务结束 - waiting bool - - // 互斥锁 - lock sync.Mutex - // 等待队列 - pending []Job -} - -// Job 任务 -type Job interface { - // 任务处理方法,如果error不为nil, - // 任务池会关闭并中止接受新任务 - Do() error + idleWorker chan int } -// NewGoroutinePool 创建一个容量为capacity的任务池 -func NewGoroutinePool(capacity int) *Pool { - pool := &Pool{ - capacity: capacity, - initialCapacity: capacity, - terminateSignal: make(chan error), - finishSignal: make(chan bool), - freeSignal: make(chan bool), +// Add 增加可用Worker数量 +func (pool *Pool) Add(num int) { + for i := 0; i < num; i++ { + pool.idleWorker <- 1 } - go pool.Schedule() - return pool } -// Schedule 为等待队列的任务分配Worker,以及检测错误状态、所有任务完成 -func (pool *Pool) Schedule() { - for { - select { - case <-pool.freeSignal: - // 有新的空余名额 - pool.lock.Lock() - if len(pool.pending) > 0 { - // 有待处理的任务,开始处理 - var job Job - job, pool.pending = pool.pending[0], pool.pending[1:] - go pool.start(job) - } else { - if pool.waiting && pool.capacity == pool.initialCapacity { - // 所有任务已结束 - pool.lock.Unlock() - pool.finishSignal <- true - return - } - pool.lock.Unlock() - } - case <-pool.terminateSignal: - // 有任务意外中止,则发送完成信号 - pool.finishSignal <- true - return - } +// ObtainWorker 阻塞直到获取新的Worker +func (pool *Pool) ObtainWorker() Worker { + select { + case <-pool.idleWorker: + // 有空闲Worker名额时,返回新Worker + return &GeneralWorker{} } } -// Wait 等待队列中所有任务完成或有Job返回错误中止 -func (pool *Pool) Wait() chan bool { - pool.lock.Lock() - pool.waiting = true - pool.lock.Unlock() - return pool.finishSignal +// FreeWorker 添加空闲Worker +func (pool *Pool) FreeWorker() { + pool.Add(1) } -// Submit 提交新任务 +// Submit 开始提交任务 func (pool *Pool) Submit(job Job) { - if pool.closed { - return - } - - pool.lock.Lock() - if pool.capacity < 1 { - // 容量为空时,加入等待队列 - pool.pending = append(pool.pending, job) - pool.lock.Unlock() - return - } - - // 还有空闲容量时,开始执行任务 - go pool.start(job) + go func() { + util.Log().Debug("等待获取Worker") + worker := pool.ObtainWorker() + util.Log().Debug("获取到Worker") + worker.Do(job) + util.Log().Debug("释放Worker") + pool.FreeWorker() + }() } -// 开始执行任务 -func (pool *Pool) start(job Job) { - pool.capacity-- - pool.lock.Unlock() - - err := job.Do() - if err != nil { - pool.closed = true - select { - case <-pool.terminateSignal: - default: - close(pool.terminateSignal) - } +// Init 初始化任务池 +func Init() { + maxWorker := model.GetIntSetting("max_worker_num", 10) + TaskPoll = &Pool{ + idleWorker: make(chan int, maxWorker), } + TaskPoll.Add(maxWorker) + util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker) - pool.lock.Lock() - pool.capacity++ - pool.lock.Unlock() - pool.freeSignal <- true + Resume() } diff --git a/pkg/task/worker.go b/pkg/task/worker.go new file mode 100644 index 0000000..9e53c77 --- /dev/null +++ b/pkg/task/worker.go @@ -0,0 +1,41 @@ +package task + +import "github.com/HFO4/cloudreve/pkg/util" + +// Worker 处理任务的对象 +type Worker interface { + Do(Job) // 执行任务 +} + +// GeneralWorker 通用Worker +type GeneralWorker struct { +} + +// Do 执行任务 +func (worker *GeneralWorker) Do(job Job) { + util.Log().Debug("开始执行任务") + job.SetStatus(Processing) + + defer func() { + // 致命错误捕获 + if err := recover(); err != nil { + util.Log().Debug("任务执行出错,panic") + job.SetError(&JobError{Msg: "致命错误"}) + job.SetStatus(Error) + } + }() + + // 开始执行任务 + job.Do() + + // 任务执行失败 + if err := job.GetError(); err != nil { + util.Log().Debug("任务执行出错") + job.SetStatus(Error) + return + } + + util.Log().Debug("任务执行完成") + // 执行完成 + job.SetStatus(Complete) +} diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 0b6ff50..b059b94 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -47,6 +47,17 @@ func Archive(c *gin.Context) { } } +// Compress 创建文件压缩任务 +func Compress(c *gin.Context) { + var service explorer.ItemCompressService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.CreateCompressTask(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // AnonymousGetContent 匿名获取文件资源 func AnonymousGetContent(c *gin.Context) { // 创建上下文 diff --git a/routers/router.go b/routers/router.go index 324f99c..f846335 100644 --- a/routers/router.go +++ b/routers/router.go @@ -267,6 +267,8 @@ func InitMasterRouter() *gin.Engine { file.GET("source/:id", controllers.GetSource) // 打包要下载的文件 file.POST("archive", controllers.Archive) + // 创建文件压缩任务 + file.POST("compress", controllers.Compress) } // 目录 diff --git a/service/explorer/file.go b/service/explorer/file.go index 9d0da1d..ea38c2c 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -75,7 +75,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con return serializer.Err(serializer.CodeNotSet, err.Error(), err) } - if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled { + if fs.User.Group.OptionsSerialized.OneTimeDownload { // 清理资源,删除临时文件 _ = cache.Deletes([]string{service.ID}, "archive_") } @@ -205,7 +205,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se // 设置文件名 c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"") - if fs.User.Group.OptionsSerialized.OneTimeDownloadEnabled { + if fs.User.Group.OptionsSerialized.OneTimeDownload { // 清理资源,删除临时文件 _ = cache.Deletes([]string{service.ID}, "download_") } diff --git a/service/explorer/objects.go b/service/explorer/objects.go index 5fa9782..4080593 100644 --- a/service/explorer/objects.go +++ b/service/explorer/objects.go @@ -9,9 +9,12 @@ import ( "github.com/HFO4/cloudreve/pkg/filesystem" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/serializer" + "github.com/HFO4/cloudreve/pkg/task" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" + "math" "net/url" + "path" "time" ) @@ -34,6 +37,80 @@ type ItemService struct { Dirs []uint `json:"dirs" binding:"exists"` } +// ItemCompressService 文件压缩任务服务 +type ItemCompressService struct { + Src ItemService `json:"src" binding:"exists"` + Dst string `json:"dst" binding:"required,min=1,max=65535"` + Name string `json:"name" binding:"required,min=1,max=255"` +} + +// CreateCompressTask 创建文件压缩任务 +func (service *ItemCompressService) CreateCompressTask(c *gin.Context) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewFileSystemFromContext(c) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + // 检查用户组权限 + if !fs.User.Group.OptionsSerialized.ArchiveTask { + return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil) + } + + // 存放目录是否存在,是否重名 + if exist, _ := fs.IsPathExist(service.Dst); !exist { + return serializer.Err(serializer.CodeNotFound, "存放路径不存在", nil) + } + if exist, _ := fs.IsFileExist(path.Join(service.Dst, service.Name)); exist { + return serializer.ParamErr("名为 "+service.Name+" 的文件已存在", nil) + } + + // 检查文件名合法性 + if !fs.ValidateLegalName(context.Background(), service.Name) { + return serializer.ParamErr("文件名非法", nil) + } + if !fs.ValidateExtension(context.Background(), service.Name) { + return serializer.ParamErr("不允许存储此扩展名的文件", nil) + } + + // 递归列出待压缩子目录 + folders, err := model.GetRecursiveChildFolder(service.Src.Dirs, fs.User.ID, true) + if err != nil { + return serializer.Err(serializer.CodeDBError, "无法列出子目录", err) + } + + // 列出所有待压缩文件 + files, err := model.GetChildFilesOfFolders(&folders) + if err != nil { + return serializer.Err(serializer.CodeDBError, "无法列出子文件", err) + } + + // 计算待压缩文件大小 + var totalSize uint64 + for i := 0; i < len(files); i++ { + totalSize += files[i].Size + } + + // 按照平均压缩率计算用户空间是否足够 + compressRatio := 0.4 + spaceNeeded := uint64(math.Round(float64(totalSize) * compressRatio)) + if fs.User.GetRemainingCapacity() < spaceNeeded { + return serializer.Err(serializer.CodeParamErr, "剩余空间不足", err) + } + + // 创建任务 + job, err := task.NewCompressTask(fs.User, path.Join(service.Dst, service.Name), service.Src.Dirs, + service.Src.Items) + if err != nil { + return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) + } + + task.TaskPoll.Submit(job) + return serializer.Response{} + +} + // Archive 创建归档 func (service *ItemService) Archive(ctx context.Context, c *gin.Context) serializer.Response { // 创建文件系统 @@ -44,13 +121,13 @@ func (service *ItemService) Archive(ctx context.Context, c *gin.Context) seriali defer fs.Recycle() // 检查用户组权限 - if !fs.User.Group.OptionsSerialized.ArchiveDownloadEnabled { + if !fs.User.Group.OptionsSerialized.ArchiveDownload { return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil) } // 开始压缩 ctx = context.WithValue(ctx, fsctx.GinCtx, c) - zipFile, err := fs.Compress(ctx, service.Dirs, service.Items) + zipFile, err := fs.Compress(ctx, service.Dirs, service.Items, true) if err != nil { return serializer.Err(serializer.CodeNotSet, "无法创建压缩文件", err) } diff --git a/service/share/visit.go b/service/share/visit.go index ebef557..6d54cc1 100644 --- a/service/share/visit.go +++ b/service/share/visit.go @@ -281,7 +281,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response { user := userCtx.(*model.User) // 是否有权限 - if !user.Group.OptionsSerialized.ArchiveDownloadEnabled { + if !user.Group.OptionsSerialized.ArchiveDownload { return serializer.Err(serializer.CodeNoPermissionErr, "您的用户组无权进行此操作", nil) } @@ -310,7 +310,7 @@ func (service *ArchiveService) Archive(c *gin.Context) serializer.Response { // 用于调下层service tempUser := share.Creator() - tempUser.Group.OptionsSerialized.ArchiveDownloadEnabled = true + tempUser.Group.OptionsSerialized.ArchiveDownload = true c.Set("user", tempUser) subService := explorer.ItemService{