Merge branch 'master' into patch-samesite

pull/1381/head
AaronLiu 3 years ago committed by GitHub
commit 4ad9649300
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,10 +7,10 @@ jobs:
name: Build name: Build
runs-on: ubuntu-18.04 runs-on: ubuntu-18.04
steps: steps:
- name: Set up Go 1.17 - name: Set up Go 1.18
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: "1.17" go-version: "1.18"
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

@ -12,10 +12,10 @@ jobs:
name: Test name: Test
runs-on: ubuntu-18.04 runs-on: ubuntu-18.04
steps: steps:
- name: Set up Go 1.17 - name: Set up Go 1.18
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: "1.17" go-version: "1.18"
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

@ -1,6 +1,6 @@
language: go language: go
go: go:
- 1.17.x - 1.18.x
node_js: "12.16.3" node_js: "12.16.3"
git: git:
depth: 1 depth: 1

@ -1,42 +1,47 @@
FROM golang:1.17-alpine as cloudreve_builder # the frontend builder
# cloudreve need node.js 16* to build frontend,
# separate build step and custom image tag will resolve this
FROM node:16-alpine as cloudreve_frontend_builder
RUN apk update \
&& apk add --no-cache wget curl git yarn zip bash \
&& git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_frontend
# install dependencies and build tools # build frontend assets using build script, make sure all the steps just follow the regular release
RUN apk update && apk add --no-cache wget curl git yarn build-base gcc abuild binutils binutils-doc gcc-doc zip WORKDIR /cloudreve_frontend
ENV GENERATE_SOURCEMAP false
RUN chmod +x ./build.sh && ./build.sh -a
WORKDIR /cloudreve_builder
RUN git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git
# build frontend # the backend builder
WORKDIR /cloudreve_builder/Cloudreve/assets # cloudreve backend needs golang 1.18* to build
ENV GENERATE_SOURCEMAP false FROM golang:1.18-alpine as cloudreve_backend_builder
RUN yarn install --network-timeout 1000000 # install dependencies and build tools
RUN yarn run build RUN apk update \
# install dependencies and build tools
&& apk add --no-cache wget curl git build-base gcc abuild binutils binutils-doc gcc-doc zip bash \
&& git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_backend
# build backend WORKDIR /cloudreve_backend
WORKDIR /cloudreve_builder/Cloudreve COPY --from=cloudreve_frontend_builder /cloudreve_frontend/assets.zip ./
RUN zip -r - assets/build >assets.zip RUN chmod +x ./build.sh && ./build.sh -c
RUN tag_name=$(git describe --tags) \
&& export COMMIT_SHA=$(git rev-parse --short HEAD) \
&& go build -a -o cloudreve -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$tag_name' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'"
# build final image # TODO: merge the frontend build and backend build into a single one image
# the final published image
FROM alpine:latest FROM alpine:latest
WORKDIR /cloudreve WORKDIR /cloudreve
COPY --from=cloudreve_backend_builder /cloudreve_backend/cloudreve ./cloudreve
RUN apk update && apk add --no-cache tzdata
RUN apk update \
# we using the `Asia/Shanghai` timezone by default, you can do modification at your will && apk add --no-cache tzdata \
RUN cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone && echo "Asia/Shanghai" > /etc/timezone \
&& chmod +x ./cloudreve \
COPY --from=cloudreve_builder /cloudreve_builder/Cloudreve/cloudreve ./ && mkdir -p /data/aria2 \
&& chmod -R 766 /data/aria2
# prepare permissions and aria2 dir
RUN chmod +x ./cloudreve && mkdir -p /data/aria2 && chmod -R 766 /data/aria2
EXPOSE 5212 EXPOSE 5212
VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"] VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"]

@ -71,7 +71,7 @@ chmod +x ./cloudreve
## :gear: 构建 ## :gear: 构建
自行构建前需要拥有 `Go >= 1.17`、`node.js`、`yarn`、`zip` 等必要依赖。 自行构建前需要拥有 `Go >= 1.18`、`node.js`、`yarn`、`zip` 等必要依赖。
#### 克隆代码 #### 克隆代码

@ -1 +1 @@
Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603 Subproject commit 2bf915a33d58fc78c9c13ffc64685219c28a4732

Binary file not shown.

@ -0,0 +1,432 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embed provides access to files embedded in the running Go program.
//
// Go source files that import "embed" can use the //go:embed directive
// to initialize a variable of type string, []byte, or FS with the contents of
// files read from the package directory or subdirectories at compile time.
//
// For example, here are three ways to embed a file named hello.txt
// and then print its contents at run time.
//
// Embedding one file into a string:
//
// import _ "embed"
//
// //go:embed hello.txt
// var s string
// print(s)
//
// Embedding one file into a slice of bytes:
//
// import _ "embed"
//
// //go:embed hello.txt
// var b []byte
// print(string(b))
//
// Embedded one or more files into a file system:
//
// import "embed"
//
// //go:embed hello.txt
// var f embed.FS
// data, _ := f.ReadFile("hello.txt")
// print(string(data))
//
// # Directives
//
// A //go:embed directive above a variable declaration specifies which files to embed,
// using one or more path.Match patterns.
//
// The directive must immediately precede a line containing the declaration of a single variable.
// Only blank lines and // line comments are permitted between the directive and the declaration.
//
// The type of the variable must be a string type, or a slice of a byte type,
// or FS (or an alias of FS).
//
// For example:
//
// package server
//
// import "embed"
//
// // content holds our static web server content.
// //go:embed image/* template/*
// //go:embed html/index.html
// var content embed.FS
//
// The Go build system will recognize the directives and arrange for the declared variable
// (in the example above, content) to be populated with the matching files from the file system.
//
// The //go:embed directive accepts multiple space-separated patterns for
// brevity, but it can also be repeated, to avoid very long lines when there are
// many patterns. The patterns are interpreted relative to the package directory
// containing the source file. The path separator is a forward slash, even on
// Windows systems. Patterns may not contain . or .. or empty path elements,
// nor may they begin or end with a slash. To match everything in the current
// directory, use * instead of .. To allow for naming files with spaces in
// their names, patterns can be written as Go double-quoted or back-quoted
// string literals.
//
// If a pattern names a directory, all files in the subtree rooted at that directory are
// embedded (recursively), except that files with names beginning with . or _
// are excluded. So the variable in the above example is almost equivalent to:
//
// // content is our static web server content.
// //go:embed image template html/index.html
// var content embed.FS
//
// The difference is that image/* embeds image/.tempfile while image does not.
// Neither embeds image/dir/.tempfile.
//
// If a pattern begins with the prefix all:, then the rule for walking directories is changed
// to include those files beginning with . or _. For example, all:image embeds
// both image/.tempfile and image/dir/.tempfile.
//
// The //go:embed directive can be used with both exported and unexported variables,
// depending on whether the package wants to make the data available to other packages.
// It can only be used with variables at package scope, not with local variables.
//
// Patterns must not match files outside the package's module, such as .git/* or symbolic links.
// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :.
// Matches for empty directories are ignored. After that, each pattern in a //go:embed line
// must match at least one file or non-empty directory.
//
// If any patterns are invalid or have invalid matches, the build will fail.
//
// # Strings and Bytes
//
// The //go:embed line for a variable of type string or []byte can have only a single pattern,
// and that pattern can match only a single file. The string or []byte is initialized with
// the contents of that file.
//
// The //go:embed directive requires importing "embed", even when using a string or []byte.
// In source files that don't refer to embed.FS, use a blank import (import _ "embed").
//
// # File Systems
//
// For embedding a single file, a variable of type string or []byte is often best.
// The FS type enables embedding a tree of files, such as a directory of static
// web server content, as in the example above.
//
// FS implements the io/fs package's FS interface, so it can be used with any package that
// understands file systems, including net/http, text/template, and html/template.
//
// For example, given the content variable in the example above, we can write:
//
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content))))
//
// template.ParseFS(content, "*.tmpl")
//
// # Tools
//
// To support tools that analyze Go packages, the patterns found in //go:embed lines
// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
// and XTestEmbedPatterns fields in the “go help list” output.
package bootstrap
import (
"errors"
"io"
"io/fs"
"time"
)
// An FS is a read-only collection of files, usually initialized with a //go:embed directive.
// When declared without a //go:embed directive, an FS is an empty file system.
//
// An FS is a read-only value, so it is safe to use from multiple goroutines
// simultaneously and also safe to assign values of type FS to each other.
//
// FS implements fs.FS, so it can be used with any package that understands
// file system interfaces, including net/http, text/template, and html/template.
//
// See the package documentation for more details about initializing an FS.
//
// NOTE(review): this is a vendored copy of the standard library's embed.FS;
// in this package the files list is populated at run time by NewFS (from a
// zip archive), not by the compiler.
type FS struct {
	// The compiler knows the layout of this struct.
	// See cmd/compile/internal/staticdata's WriteEmbed.
	//
	// The files list is sorted by name but not by simple string comparison.
	// Instead, each file's name takes the form "dir/elem" or "dir/elem/".
	// The optional trailing slash indicates that the file is itself a directory.
	// The files list is sorted first by dir (if dir is missing, it is taken to be ".")
	// and then by base, so this list of files:
	//
	//	p
	//	q/
	//	q/r
	//	q/s/
	//	q/s/t
	//	q/s/u
	//	q/v
	//	w
	//
	// is actually sorted as:
	//
	//	p       # dir=. elem=p
	//	q/      # dir=. elem=q
	//	w/      # dir=. elem=w
	//	q/r     # dir=q elem=r
	//	q/s/    # dir=q elem=s
	//	q/v     # dir=q elem=v
	//	q/s/t   # dir=q/s elem=t
	//	q/s/u   # dir=q/s elem=u
	//
	// This order brings directory contents together in contiguous sections
	// of the list, allowing a directory read to use binary search to find
	// the relevant sequence of entries.
	files *[]file
}
// split breaks name into its parent directory and final path element,
// following the naming scheme described on the FS struct. isDir reports
// whether name carried a trailing slash, which marks a directory entry.
// A name with no slash (after trimming) belongs to the root dir ".".
func split(name string) (dir, elem string, isDir bool) {
	isDir = name[len(name)-1] == '/'
	if isDir {
		name = name[:len(name)-1]
	}
	// Forward scan remembering the last slash; equivalent to a reverse
	// search for the final path separator.
	slash := -1
	for i := 0; i < len(name); i++ {
		if name[i] == '/' {
			slash = i
		}
	}
	if slash < 0 {
		return ".", name, isDir
	}
	return name[:slash], name[slash+1:], isDir
}
// trimSlash returns name without its trailing slash, when one is present;
// otherwise it returns name unchanged. The empty string is returned as-is.
func trimSlash(name string) string {
	n := len(name)
	if n == 0 || name[n-1] != '/' {
		return name
	}
	return name[:n-1]
}
var (
_ fs.ReadDirFS = FS{}
_ fs.ReadFileFS = FS{}
)
// A file is a single file in the FS.
// It implements fs.FileInfo and fs.DirEntry.
type file struct {
	// The compiler knows the layout of this struct.
	// See cmd/compile/internal/staticdata's WriteEmbed.
	// NOTE(review): in this vendored copy the fields are filled by
	// bootstrap.NewFS rather than by the compiler; the layout comment is
	// retained from the upstream embed package.
	name string   // "dir/elem", with a trailing slash when it is a directory
	data string   // full file contents; empty for directories
	hash [16]byte // truncated SHA256 hash
}
var (
_ fs.FileInfo = (*file)(nil)
_ fs.DirEntry = (*file)(nil)
)
// Name returns the file's base name (fs.FileInfo / fs.DirEntry).
func (f *file) Name() string { _, elem, _ := split(f.name); return elem }

// Size returns the content length in bytes (0 for directories).
func (f *file) Size() int64 { return int64(len(f.data)) }

// ModTime returns the zero time: embedded files carry no timestamps.
func (f *file) ModTime() time.Time { return time.Time{} }

// IsDir reports whether the stored name ends in a trailing slash.
func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir }

// Sys returns nil; there is no underlying system representation.
func (f *file) Sys() any { return nil }

// Type returns the mode's type bits (fs.DirEntry).
func (f *file) Type() fs.FileMode { return f.Mode().Type() }

// Info returns the file itself, which doubles as its fs.FileInfo.
func (f *file) Info() (fs.FileInfo, error) { return f, nil }

// Mode reports read-only permissions: r-xr-xr-x for directories,
// r--r--r-- for regular files.
func (f *file) Mode() fs.FileMode {
	if f.IsDir() {
		return fs.ModeDir | 0555
	}
	return 0444
}
// dotFile is a file for the root directory,
// which is omitted from the files list in a FS.
var dotFile = &file{name: "./"}

// lookup returns the named file, or nil if it is not present.
// name follows the fs.ValidPath convention: "." for the root,
// slash-separated, no leading or trailing slash.
func (f FS) lookup(name string) *file {
	if !fs.ValidPath(name) {
		// The compiler should never emit a file with an invalid name,
		// so this check is not strictly necessary (if name is invalid,
		// we shouldn't find a match below), but it's a good backstop anyway.
		return nil
	}
	if name == "." {
		return dotFile
	}
	if f.files == nil {
		// An FS declared without initialization is an empty file system.
		return nil
	}
	// Binary search to find where name would be in the list,
	// and then check if name is at that position.
	// The comparator mirrors the "dir first, then elem" sort order
	// documented on the FS struct.
	dir, elem, _ := split(name)
	files := *f.files
	i := sortSearch(len(files), func(i int) bool {
		idir, ielem, _ := split(files[i].name)
		return idir > dir || idir == dir && ielem >= elem
	})
	if i < len(files) && trimSlash(files[i].name) == name {
		return &files[i]
	}
	return nil
}
// readDir returns the list of files corresponding to the directory dir.
// Because the files list is sorted first by directory, all entries of dir
// form one contiguous run, which two binary searches bound from each side.
func (f FS) readDir(dir string) []file {
	if f.files == nil {
		return nil
	}
	// Binary search to find where dir starts and ends in the list
	// and then return that slice of the list.
	files := *f.files
	i := sortSearch(len(files), func(i int) bool {
		idir, _, _ := split(files[i].name)
		return idir >= dir
	})
	j := sortSearch(len(files), func(j int) bool {
		jdir, _, _ := split(files[j].name)
		return jdir > dir
	})
	return files[i:j]
}
// Open opens the named file for reading and returns it as an fs.File.
// A name that is not present yields an fs.PathError wrapping fs.ErrNotExist.
//
// The returned file implements io.Seeker when the file is not a directory.
func (f FS) Open(name string) (fs.File, error) {
	file := f.lookup(name)
	if file == nil {
		return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
	}
	if file.IsDir() {
		// Directories get a stateful reader over their contiguous entries.
		return &openDir{file, f.readDir(name), 0}, nil
	}
	return &openFile{file, 0}, nil
}
// ReadDir reads and returns the entire named directory.
// It implements fs.ReadDirFS; opening a non-directory name yields an
// fs.PathError with Op "read".
func (f FS) ReadDir(name string) ([]fs.DirEntry, error) {
	file, err := f.Open(name)
	if err != nil {
		return nil, err
	}
	dir, ok := file.(*openDir)
	if !ok {
		return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")}
	}
	// Each *file already implements fs.DirEntry, so the listing is just
	// pointers into the directory's contiguous entry slice.
	list := make([]fs.DirEntry, len(dir.files))
	for i := range list {
		list[i] = &dir.files[i]
	}
	return list, nil
}
// ReadFile reads and returns the content of the named file.
// It implements fs.ReadFileFS; a directory name yields an fs.PathError
// with Op "read".
func (f FS) ReadFile(name string) ([]byte, error) {
	file, err := f.Open(name)
	if err != nil {
		return nil, err
	}
	ofile, ok := file.(*openFile)
	if !ok {
		return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")}
	}
	// The conversion copies the immutable string data into a fresh byte
	// slice, so callers cannot alias the embedded contents.
	return []byte(ofile.f.data), nil
}
// An openFile is a regular file open for reading.
type openFile struct {
	f      *file // the file itself
	offset int64 // current read offset
}

// Compile-time check: non-directory files additionally support seeking.
var (
	_ io.Seeker = (*openFile)(nil)
)

// Close implements fs.File; nothing needs releasing for in-memory data.
func (f *openFile) Close() error { return nil }

// Stat implements fs.File, returning the underlying file, which is its
// own fs.FileInfo.
func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil }
// Read implements io.Reader, copying bytes from the embedded contents
// starting at the current offset and advancing it by the amount copied.
// It returns io.EOF once the offset reaches the end of the data.
func (f *openFile) Read(b []byte) (int, error) {
	switch {
	case f.offset >= int64(len(f.f.data)):
		return 0, io.EOF
	case f.offset < 0:
		return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid}
	}
	n := copy(b, f.f.data[f.offset:])
	f.offset += int64(n)
	return n, nil
}
// Seek implements io.Seeker, repositioning the read offset within the
// embedded file's data. whence follows the io.Seeker convention; the
// named io constants replace the raw 0/1/2 literals used previously.
// An out-of-range result yields an fs.PathError wrapping fs.ErrInvalid.
func (f *openFile) Seek(offset int64, whence int) (int64, error) {
	switch whence {
	case io.SeekStart:
		// offset is already absolute
	case io.SeekCurrent:
		offset += f.offset
	case io.SeekEnd:
		offset += int64(len(f.f.data))
	}
	if offset < 0 || offset > int64(len(f.f.data)) {
		return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid}
	}
	f.offset = offset
	return offset, nil
}
// An openDir is a directory open for reading.
type openDir struct {
	f      *file  // the directory file itself
	files  []file // the directory contents
	offset int    // the read offset, an index into the files slice
}

// Close implements fs.File; no resources are held.
func (d *openDir) Close() error { return nil }

// Stat implements fs.File, returning the directory's own fs.FileInfo.
func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil }

// Read implements fs.File; reading byte content from a directory is an error.
func (d *openDir) Read([]byte) (int, error) {
	return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")}
}
// ReadDir returns up to count directory entries, advancing the read
// offset, per the fs.ReadDirFile contract: count <= 0 returns everything
// remaining (with a nil error even at the end), while count > 0 returns
// io.EOF once the directory is exhausted.
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
	n := len(d.files) - d.offset
	if n == 0 {
		if count <= 0 {
			return nil, nil
		}
		return nil, io.EOF
	}
	if count > 0 && n > count {
		n = count
	}
	list := make([]fs.DirEntry, n)
	for i := range list {
		list[i] = &d.files[d.offset+i]
	}
	d.offset += n
	return list, nil
}
// sortSearch is like sort.Search, avoiding an import: it returns the
// smallest index in [0, n] at which f reports true, assuming f is false
// for some (possibly empty) prefix of indices and true for the rest.
// If f is never true, it returns n.
func sortSearch(n int, f func(int) bool) int {
	lo, hi := 0, n
	for lo < hi {
		mid := lo + (hi-lo)/2 // overflow-safe midpoint, lo <= mid < hi
		if f(mid) {
			hi = mid // invariant: f(hi) == true
		} else {
			lo = mid + 1 // invariant: f(lo-1) == false
		}
	}
	// lo == hi; f(lo-1) == false and f(lo) == true, so lo is the answer.
	return lo
}

@ -0,0 +1,75 @@
package bootstrap
import (
"archive/zip"
"crypto/sha256"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/pkg/errors"
"io"
"io/fs"
"sort"
"strings"
)
// NewFS builds an in-memory fs.FS from zipContent, the raw bytes of a
// zip archive holding the embedded static assets (passed as a string so
// it can come straight from a //go:embed variable).
//
// Every archive entry is loaded into memory as a file value, and the
// resulting list is sorted into the "dir, then elem" order that
// FS.lookup's binary search requires. The process panics (via util.Log)
// if the archive is malformed or any entry cannot be read, since the
// application cannot serve its frontend without these assets.
//
// Error messages are in English for consistency with the rest of the
// code base; the previous Chinese messages also claimed entries would be
// "skipped" even though returning an error aborts the whole walk.
func NewFS(zipContent string) fs.FS {
	zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent)))
	if err != nil {
		util.Log().Panic("Static resource is not a valid zip file: %s", err)
	}

	var files []file
	err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// Returning an error aborts the walk and triggers the Panic below.
			return errors.Errorf("failed to get info of %q: %s", path, err)
		}

		if path == "." {
			return nil
		}

		var f file
		if d.IsDir() {
			// Directories are stored with a trailing slash; see split.
			f.name = path + "/"
		} else {
			f.name = path
			rc, err := zipReader.Open(path)
			if err != nil {
				return errors.Errorf("failed to open file %q: %s", path, err)
			}
			// The defer runs when this walk callback returns, so each
			// entry's reader is closed before the next one is opened.
			defer rc.Close()
			data, err := io.ReadAll(rc)
			if err != nil {
				return errors.Errorf("failed to read file %q: %s", path, err)
			}

			f.data = string(data)
			// Fill the hash field so the struct matches the layout of a
			// compiler-generated embed file. The bits are inverted —
			// presumably to distinguish these values from real truncated
			// SHA256 hashes emitted by the toolchain; the field is not
			// read anywhere in this package.
			hash := sha256.Sum256(data)
			for i := range f.hash {
				f.hash[i] = ^hash[i]
			}
		}
		files = append(files, f)
		return nil
	})
	if err != nil {
		util.Log().Panic("Failed to initialize static resources: %s", err)
	}

	// Sort into the order documented on FS: by parent directory first,
	// then by base element, so each directory's entries are contiguous.
	sort.Slice(files, func(i, j int) bool {
		fi, fj := files[i], files[j]
		di, ei, _ := split(fi.name)
		dj, ej, _ := split(fj.name)

		if di != dj {
			return di < dj
		}

		return ei < ej
	})

	var embedFS FS
	embedFS.files = &files
	return embedFS
}

@ -32,11 +32,15 @@ buildAssets() {
yarn run build yarn run build
cd build cd build
cd $REPO cd $REPO
# please keep in mind that if this final output binary `assets.zip` name changed, please go and update the `Dockerfile` as well
zip -r - assets/build >assets.zip zip -r - assets/build >assets.zip
} }
buildBinary() { buildBinary() {
cd $REPO cd $REPO
# same as assets, if this final output binary `cloudreve` name changed, please go and update the `Dockerfile`
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'" go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
} }

@ -1,6 +1,6 @@
module github.com/cloudreve/Cloudreve/v3 module github.com/cloudreve/Cloudreve/v3
go 1.17 go 1.18
require ( require (
github.com/DATA-DOG/go-sqlmock v1.3.3 github.com/DATA-DOG/go-sqlmock v1.3.3
@ -100,6 +100,7 @@ require (
github.com/mattn/go-colorable v0.1.4 // indirect github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/mattn/go-sqlite3 v1.14.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect

@ -4,13 +4,11 @@ import (
"context" "context"
_ "embed" _ "embed"
"flag" "flag"
"io"
"io/fs" "io/fs"
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"syscall" "syscall"
"time" "time"
@ -19,8 +17,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/routers" "github.com/cloudreve/Cloudreve/v3/routers"
"github.com/mholt/archiver/v4"
) )
var ( var (
@ -35,15 +31,12 @@ var staticZip string
var staticFS fs.FS var staticFS fs.FS
func init() { func init() {
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径") flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "Path to the config file.")
flag.BoolVar(&isEject, "eject", false, "导出内置静态资源") flag.BoolVar(&isEject, "eject", false, "Eject all embedded static files.")
flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本") flag.StringVar(&scriptName, "database-script", "", "Name of database util script.")
flag.Parse() flag.Parse()
staticFS = archiver.ArchiveFS{ staticFS = bootstrap.NewFS(staticZip)
Stream: io.NewSectionReader(strings.NewReader(staticZip), 0, int64(len(staticZip))),
Format: archiver.Zip{},
}
bootstrap.Init(confPath, staticFS) bootstrap.Init(confPath, staticFS)
} }
@ -71,7 +64,7 @@ func main() {
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() { go func() {
sig := <-sigChan sig := <-sigChan
util.Log().Info("收到信号 %s开始关闭 server", sig) util.Log().Info("Signal %s received, shutting down server...", sig)
ctx := context.Background() ctx := context.Background()
if conf.SystemConfig.GracePeriod != 0 { if conf.SystemConfig.GracePeriod != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
@ -81,16 +74,16 @@ func main() {
err := server.Shutdown(ctx) err := server.Shutdown(ctx)
if err != nil { if err != nil {
util.Log().Error("关闭 server 错误, %s", err) util.Log().Error("Failed to shutdown server: %s", err)
} }
}() }()
// 如果启用了SSL // 如果启用了SSL
if conf.SSLConfig.CertPath != "" { if conf.SSLConfig.CertPath != "" {
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen) util.Log().Info("Listening to %q", conf.SSLConfig.Listen)
server.Addr = conf.SSLConfig.Listen server.Addr = conf.SSLConfig.Listen
if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil { if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SSLConfig.Listen, err) util.Log().Error("Failed to listen to %q: %s", conf.SSLConfig.Listen, err)
return return
} }
} }
@ -100,23 +93,23 @@ func main() {
// delete socket file before listening // delete socket file before listening
if _, err := os.Stat(conf.UnixConfig.Listen); err == nil { if _, err := os.Stat(conf.UnixConfig.Listen); err == nil {
if err = os.Remove(conf.UnixConfig.Listen); err != nil { if err = os.Remove(conf.UnixConfig.Listen); err != nil {
util.Log().Error("删除 socket 文件错误, %s", err) util.Log().Error("Failed to delete socket file: %s", err)
return return
} }
} }
api.TrustedPlatform = conf.UnixConfig.ProxyHeader api.TrustedPlatform = conf.UnixConfig.ProxyHeader
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen) util.Log().Info("Listening to %q", conf.UnixConfig.Listen)
if err := RunUnix(server); err != nil { if err := RunUnix(server); err != nil {
util.Log().Error("无法监听[%s]%s", conf.UnixConfig.Listen, err) util.Log().Error("Failed to listen to %q: %s", conf.UnixConfig.Listen, err)
} }
return return
} }
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen) util.Log().Info("Listening to %q", conf.SystemConfig.Listen)
server.Addr = conf.SystemConfig.Listen server.Addr = conf.SystemConfig.Listen
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SystemConfig.Listen, err) util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err)
} }
} }
@ -125,8 +118,21 @@ func RunUnix(server *http.Server) error {
if err != nil { if err != nil {
return err return err
} }
defer listener.Close() defer listener.Close()
defer os.Remove(conf.UnixConfig.Listen) defer os.Remove(conf.UnixConfig.Listen)
if conf.UnixConfig.Perm > 0 {
err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(conf.UnixConfig.Perm))
if err != nil {
util.Log().Warning(
"Failed to set permission to %q for socket file %q: %s",
conf.UnixConfig.Perm,
conf.UnixConfig.Listen,
err,
)
}
}
return server.Serve(listener) return server.Serve(listener)
} }

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -45,3 +46,17 @@ func CacheControl() gin.HandlerFunc {
c.Header("Cache-Control", "private, no-cache") c.Header("Cache-Control", "private, no-cache")
} }
} }
// Sandbox attaches a Content-Security-Policy "sandbox" header to every
// response, restricting what the served content is allowed to do in the
// browser.
func Sandbox() gin.HandlerFunc {
	const header, policy = "Content-Security-Policy", "sandbox"
	return func(c *gin.Context) {
		c.Header(header, policy)
	}
}
// StaticResourceCache applies the public caching policy for static
// resources: max-age is read from the "public_resource_maxage" setting,
// falling back to 86400 seconds (one day).
func StaticResourceCache() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400)))
	}
}

@ -85,3 +85,21 @@ func TestCacheControl(t *testing.T) {
TestFunc(c) TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache") a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
} }
// TestSandbox verifies that the middleware sets the CSP sandbox header.
func TestSandbox(t *testing.T) {
	a := assert.New(t)
	TestFunc := Sandbox()

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	TestFunc(c)
	a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
}
// TestStaticResourceCache verifies that the middleware sets a public
// Cache-Control header with a max-age directive.
func TestStaticResourceCache(t *testing.T) {
	a := assert.New(t)
	TestFunc := StaticResourceCache()

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	TestFunc(c)
	a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
}

@ -0,0 +1,30 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
)
// ValidateSourceLink validates if the perm source link is a valid redirect link.
// On any failure it responds with CodeFileNotFound and aborts the chain;
// on success it records the download and stores the link in the context
// under "source_link" for downstream handlers.
func ValidateSourceLink() gin.HandlerFunc {
	return func(c *gin.Context) {
		// "object_id" is expected to be placed in the context by an
		// upstream middleware (hashid decoding) — TODO confirm against
		// the router setup.
		linkID, ok := c.Get("object_id")
		if !ok {
			c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
			c.Abort()
			return
		}

		// The link must exist, still reference a live file (File.ID != 0),
		// and the requested name must match, so stale or forged URLs fail.
		sourceLink, err := model.GetSourceLinkByID(linkID)
		if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") {
			c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
			c.Abort()
			return
		}

		// Record the download and expose the link to later handlers.
		sourceLink.Downloaded()
		c.Set("source_link", sourceLink)
		c.Next()
	}
}

@ -0,0 +1,57 @@
package middleware
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
// TestValidateSourceLink covers the middleware's four paths: missing
// object_id, missing source-link row, missing backing file, and success.
// The package-level sqlmock `mock` (defined in the suite setup elsewhere
// in this package) backs the model queries.
func TestValidateSourceLink(t *testing.T) {
	a := assert.New(t)
	rec := httptest.NewRecorder()
	testFunc := ValidateSourceLink()

	// object_id missing from context
	{
		c, _ := gin.CreateTestContext(rec)
		testFunc(c)
		a.True(c.IsAborted())
	}

	// source link row does not exist
	{
		c, _ := gin.CreateTestContext(rec)
		c.Set("object_id", 1)
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
		testFunc(c)
		a.True(c.IsAborted())
		a.NoError(mock.ExpectationsWereMet())
	}

	// the referenced file no longer exists
	{
		c, _ := gin.CreateTestContext(rec)
		c.Set("object_id", 1)
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
		mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
		testFunc(c)
		a.True(c.IsAborted())
		a.NoError(mock.ExpectationsWereMet())
	}

	// success: link and file found, download counter updated
	{
		c, _ := gin.CreateTestContext(rec)
		c.Set("object_id", 1)
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
		mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
		mock.ExpectBegin()
		mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
		testFunc(c)
		a.False(c.IsAborted())
		a.NoError(mock.ExpectationsWereMet())
	}
}

@ -39,7 +39,11 @@ func FrontendFileHandler() gin.HandlerFunc {
path := c.Request.URL.Path path := c.Request.URL.Path
// API 跳过 // API 跳过
if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" { if strings.HasPrefix(path, "/api") ||
strings.HasPrefix(path, "/custom") ||
strings.HasPrefix(path, "/dav") ||
strings.HasPrefix(path, "/f") ||
path == "/manifest.json" {
c.Next() c.Next()
return return
} }

@ -46,7 +46,7 @@ func Session(secret string) gin.HandlerFunc {
// Also set Secure: true if using SSL, you should though // Also set Secure: true if using SSL, you should though
Store.Options(sessions.Options{ Store.Options(sessions.Options{
HttpOnly: true, HttpOnly: true,
MaxAge: 7 * 86400, MaxAge: 60 * 86400,
Path: "/", Path: "/",
SameSite: sameSiteMode, SameSite: sameSiteMode,
Secure: conf.CORSConfig.Secure, Secure: conf.CORSConfig.Secure,

@ -113,4 +113,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"}, {Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"}, {Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"}, {Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
{Name: "show_app_promotion", Value: "1", Type: "mobile"},
{Name: "public_resource_maxage", Value: "86400", Type: "timeout"},
} }

@ -32,6 +32,7 @@ type Download struct {
// 数据库忽略字段 // 数据库忽略字段
StatusInfo rpc.StatusInfo `gorm:"-"` StatusInfo rpc.StatusInfo `gorm:"-"`
Task *Task `gorm:"-"` Task *Task `gorm:"-"`
NodeName string `gorm:"-"`
} }
// AfterFind 找到下载任务后的钩子处理Status结构 // AfterFind 找到下载任务后的钩子处理Status结构

@ -4,6 +4,7 @@ import (
"encoding/gob" "encoding/gob"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"path" "path"
"time" "time"
@ -191,14 +192,15 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) {
} }
// 查询软链接的文件 // 查询软链接的文件
var filesWithSoftLinks []File filesWithSoftLinks := make([]File, 0)
tx := DB for _, file := range files {
for _, value := range files { var softLinkFile File
tx = tx.Or("source_name = ? and policy_id = ? and id != ?", value.SourceName, value.PolicyID, value.ID) res := DB.
Where("source_name = ? and policy_id = ? and id != ?", file.SourceName, file.PolicyID, file.ID).
First(&softLinkFile)
if res.Error == nil {
filesWithSoftLinks = append(filesWithSoftLinks, softLinkFile)
} }
result := tx.Find(&filesWithSoftLinks)
if result.Error != nil {
return nil, result.Error
} }
// 过滤具有软连接的文件 // 过滤具有软连接的文件
@ -338,6 +340,25 @@ func (file *File) CanCopy() bool {
return file.UploadSessionID == nil return file.UploadSessionID == nil
} }
// CreateOrGetSourceLink creates a SourceLink model. If the given model exists, the existing
// model will be returned.
func (file *File) CreateOrGetSourceLink() (*SourceLink, error) {
	link := &SourceLink{}
	// Look up an existing source link for this file; auto_preload pulls in associations.
	if err := DB.Set("gorm:auto_preload", true).Where("file_id = ?", file.ID).Find(&link).Error; err == nil && link.ID > 0 {
		return link, nil
	}

	// No usable record found: persist a fresh link annotated with the file's current name.
	link.FileID = file.ID
	link.Name = file.Name
	if err := DB.Save(link).Error; err != nil {
		return nil, fmt.Errorf("failed to insert SourceLink: %w", err)
	}

	link.File = *file
	return link, nil
}
/* /*
webdav.FileInfo webdav.FileInfo
*/ */

@ -285,30 +285,34 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
}, },
} }
// 传入空文件列表
{
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.NoError(err)
asserts.Empty(file)
}
// 全都没有 // 全都没有
{ {
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2). WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files) file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err) asserts.NoError(err)
asserts.Equal(files, file) asserts.Equal(files, file)
} }
// 查询出错
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnError(errors.New("error"))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Nil(file)
}
// 第二个是软链 // 第二个是软链
{ {
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2). WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows( WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"), AddRow(3, 24, "2.txt"),
@ -318,14 +322,18 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts.NoError(err) asserts.NoError(err)
asserts.Equal(files[:1], file) asserts.Equal(files[:1], file)
} }
// 第一个是软链 // 第一个是软链
{ {
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2). WithArgs("1.txt", 23, 1).
WillReturnRows( WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"), AddRow(3, 23, "1.txt"),
) )
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files) file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err) asserts.NoError(err)
@ -334,11 +342,16 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
// 全部是软链 // 全部是软链
{ {
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2). WithArgs("1.txt", 23, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows( WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"). AddRow(3, 24, "2.txt"),
AddRow(4, 23, "1.txt"),
) )
file, err := RemoveFilesWithSoftLinks(files) file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(mock.ExpectationsWereMet())
@ -598,3 +611,44 @@ func TestGetFilesByKeywords(t *testing.T) {
asserts.Len(res, 1) asserts.Len(res, 1)
} }
} }
// TestFile_CreateOrGetSourceLink covers the three outcomes of File.CreateOrGetSourceLink:
// reusing an existing record, a failed insert, and a successful insert.
func TestFile_CreateOrGetSourceLink(t *testing.T) {
	a := assert.New(t)
	file := &File{}
	file.ID = 1
	// Existing record found: the stored SourceLink is returned unchanged.
	{
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
		res, err := file.CreateOrGetSourceLink()
		a.NoError(err)
		a.EqualValues(2, res.ID)
		a.NoError(mock.ExpectationsWereMet())
	}
	// Not found and the INSERT fails: the database error is propagated (wrapped, so ErrorIs still matches).
	{
		expectedErr := errors.New("error")
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
		mock.ExpectBegin()
		mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr)
		mock.ExpectRollback()
		res, err := file.CreateOrGetSourceLink()
		a.Nil(res)
		a.ErrorIs(err, expectedErr)
		a.NoError(mock.ExpectationsWereMet())
	}
	// Not found and the INSERT succeeds: a new link carrying the generated ID and the file is returned.
	{
		mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
		mock.ExpectBegin()
		mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
		mock.ExpectCommit()
		res, err := file.CreateOrGetSourceLink()
		a.NoError(err)
		a.EqualValues(2, res.ID)
		a.EqualValues(file.ID, res.File.ID)
		a.NoError(mock.ExpectationsWereMet())
	}
}

@ -224,7 +224,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
} else if IDCache, ok := newIDCache[*folder.ParentID]; ok { } else if IDCache, ok := newIDCache[*folder.ParentID]; ok {
newID = IDCache newID = IDCache
} else { } else {
util.Log().Warning("Failed to get parent folder %q", folder.ParentID) util.Log().Warning("Failed to get parent folder %q", *folder.ParentID)
return size, errors.New("Failed to get parent folder") return size, errors.New("Failed to get parent folder")
} }

@ -32,6 +32,7 @@ type GroupOption struct {
Aria2 bool `json:"aria2,omitempty"` // 离线下载 Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
SourceBatchSize int `json:"source_batch,omitempty"` SourceBatchSize int `json:"source_batch,omitempty"`
RedirectedSource bool `json:"redirected_source,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"` Aria2BatchSize int `json:"aria2_batch,omitempty"`
} }

@ -41,7 +41,7 @@ func migration() {
} }
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{}, DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}) &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{})
// 创建初始存储策略 // 创建初始存储策略
addDefaultPolicy() addDefaultPolicy()
@ -110,6 +110,7 @@ func addDefaultGroups() {
Aria2: true, Aria2: true,
SourceBatchSize: 1000, SourceBatchSize: 1000,
Aria2BatchSize: 50, Aria2BatchSize: 50,
RedirectedSource: true,
}, },
} }
if err := DB.Create(&defaultAdminGroup).Error; err != nil { if err := DB.Create(&defaultAdminGroup).Error; err != nil {
@ -131,6 +132,7 @@ func addDefaultGroups() {
ShareDownload: true, ShareDownload: true,
SourceBatchSize: 10, SourceBatchSize: 10,
Aria2BatchSize: 1, Aria2BatchSize: 1,
RedirectedSource: true,
}, },
} }
if err := DB.Create(&defaultAdminGroup).Error; err != nil { if err := DB.Create(&defaultAdminGroup).Error; err != nil {

@ -15,12 +15,12 @@ var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error { func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok { if script, ok := availableScripts[name]; ok {
util.Log().Info("开始执行数据库脚本 [%s]", name) util.Log().Info("Start executing database script %q.", name)
script.Run(ctx) script.Run(ctx)
return nil return nil
} }
return fmt.Errorf("数据库脚本 [%s] 不存在", name) return fmt.Errorf("Database script %q not exist.", name)
} }
func Register(name string, script DBScript) { func Register(name string, script DBScript) {

@ -14,7 +14,7 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 查找用户 // 查找用户
user, err := model.GetUserByID(1) user, err := model.GetUserByID(1)
if err != nil { if err != nil {
util.Log().Panic("初始管理员用户不存在, %s", err) util.Log().Panic("Initial admin user not exist: %s", err)
} }
// 生成密码 // 生成密码
@ -23,9 +23,9 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 更改为新密码 // 更改为新密码
user.SetPassword(password) user.SetPassword(password)
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil { if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
util.Log().Panic("密码更改失败, %s", err) util.Log().Panic("Failed to update password: %s", err)
} }
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold) c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password)) util.Log().Info("Initial admin user password changed to:" + c.Sprint(password))
} }

@ -25,7 +25,7 @@ func (script UserStorageCalibration) Run(ctx context.Context) {
model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total) model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total)
// 更新用户的容量 // 更新用户的容量
if user.Storage != total.Total { if user.Storage != total.Total {
util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email, util.Log().Info("Calibrate used storage for user %q, from %d to %d.", user.Email,
user.Storage, total.Total) user.Storage, total.Total)
} }
model.DB.Model(&user).Update("storage", total.Total) model.DB.Model(&user).Update("storage", total.Total)

@ -0,0 +1,47 @@
package model
import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/jinzhu/gorm"
"net/url"
)
// SourceLink represent a shared file source link
type SourceLink struct {
	gorm.Model
	FileID    uint   // corresponding file ID
	Name      string // name of the file while creating the source link, for annotation
	Downloads int    // number of downloads

	// Associated model
	// NOTE(review): the tag value `save_associations:false:false` looks like a typo
	// for `save_associations:false` — confirm against gorm's tag syntax.
	File File `gorm:"save_associations:false:false"`
}
// Link gets the absolute URL of a SourceLink, built from the site base URL,
// the hashed link ID, and the associated file's name.
func (s *SourceLink) Link() (string, error) {
	base := GetSiteURL()
	ref, err := url.Parse(fmt.Sprintf("/f/%s/%s", hashid.HashID(s.ID, hashid.SourceLinkID), s.File.Name))
	if err != nil {
		return "", err
	}
	return base.ResolveReference(ref).String(), nil
}
// GetSourceLinkByID queries a source link by its ID and loads the associated file.
func GetSourceLinkByID(id interface{}) (*SourceLink, error) {
	link := &SourceLink{}
	result := DB.Where("id = ?", id).First(link)
	// The associated file is fetched best-effort even if the link lookup failed;
	// result.Error below still reports the original lookup outcome.
	files, _ := GetFilesByIDs([]uint{link.FileID}, 0)
	if len(files) > 0 {
		link.File = files[0]
	}

	return link, result.Error
}
// Downloaded increments the download counter, both on the in-memory model and
// atomically in the database via a SQL expression.
func (s *SourceLink) Downloaded() {
	s.Downloads++
	DB.Model(s).UpdateColumn("downloads", gorm.Expr("downloads + ?", 1))
}

@ -0,0 +1,52 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
// TestSourceLink_Link verifies URL generation for a source link: an invalid
// file name makes url.Parse fail, and a valid name yields a URL containing it.
func TestSourceLink_Link(t *testing.T) {
	a := assert.New(t)
	s := &SourceLink{}
	s.ID = 1
	// Failure: a control character in the file name breaks URL parsing.
	{
		s.File.Name = string([]byte{0x7f})
		res, err := s.Link()
		a.Error(err)
		a.Empty(res)
	}
	// Success: the generated URL contains the file name.
	{
		s.File.Name = "filename"
		res, err := s.Link()
		a.NoError(err)
		// Assert against the file name actually used by Link(). The original
		// checked s.Name, which is never set here, making the assertion vacuous
		// (Contains with "" always passes).
		a.Contains(res, s.File.Name)
	}
}
// TestGetSourceLinkByID verifies that the link row is fetched by ID and that
// the associated file row is loaded into the returned model.
func TestGetSourceLinkByID(t *testing.T) {
	a := assert.New(t)
	mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
	mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
	res, err := GetSourceLinkByID(1)
	a.NoError(err)
	a.NotNil(res)
	a.EqualValues(2, res.File.ID)
	a.NoError(mock.ExpectationsWereMet())
}
// TestSourceLink_Downloaded verifies that Downloaded issues an UPDATE on the
// source_links table inside a transaction.
func TestSourceLink_Downloaded(t *testing.T) {
	a := assert.New(t)
	s := &SourceLink{}
	s.ID = 1
	mock.ExpectBegin()
	mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()
	s.Downloaded()
	a.NoError(mock.ExpectationsWereMet())
}

@ -52,7 +52,7 @@ const (
var ( var (
// ErrNotEnabled 功能未开启错误 // ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "", nil) ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
// ErrUserNotFound 未找到下载任务创建者 // ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil) ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
) )

@ -38,6 +38,7 @@ type ssl struct {
type unix struct { type unix struct {
Listen string Listen string
ProxyHeader string `validate:"required_with=Listen"` ProxyHeader string `validate:"required_with=Listen"`
Perm uint32
} }
// slave 作为slave存储端配置 // slave 作为slave存储端配置

@ -1,13 +1,13 @@
package conf package conf
// BackendVersion 当前后端版本号 // BackendVersion 当前后端版本号
var BackendVersion = "3.5.3" var BackendVersion = "3.6.0"
// RequiredDBVersion 与当前版本匹配的数据库版本 // RequiredDBVersion 与当前版本匹配的数据库版本
var RequiredDBVersion = "3.5.2" var RequiredDBVersion = "3.6.0"
// RequiredStaticVersion 与当前版本匹配的静态资源版本 // RequiredStaticVersion 与当前版本匹配的静态资源版本
var RequiredStaticVersion = "3.5.3" var RequiredStaticVersion = "3.6.0"
// IsPro 是否为Pro版本 // IsPro 是否为Pro版本
var IsPro = "false" var IsPro = "false"

@ -1,14 +1,22 @@
package backoff package backoff
import "time" import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strconv"
"time"
)
// Backoff used for retry sleep backoff // Backoff used for retry sleep backoff
type Backoff interface { type Backoff interface {
Next() bool Next(err error) bool
Reset() Reset()
} }
// ConstantBackoff implements Backoff interface with constant sleep time // ConstantBackoff implements Backoff interface with constant sleep time. If the error
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
type ConstantBackoff struct { type ConstantBackoff struct {
Sleep time.Duration Sleep time.Duration
Max int Max int
@ -16,16 +24,51 @@ type ConstantBackoff struct {
tried int tried int
} }
func (c *ConstantBackoff) Next() bool { func (c *ConstantBackoff) Next(err error) bool {
c.tried++ c.tried++
if c.tried > c.Max { if c.tried > c.Max {
return false return false
} }
var e *RetryableError
if errors.As(err, &e) && e.RetryAfter > 0 {
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
time.Sleep(e.RetryAfter)
} else {
time.Sleep(c.Sleep) time.Sleep(c.Sleep)
}
return true return true
} }
func (c *ConstantBackoff) Reset() { func (c *ConstantBackoff) Reset() {
c.tried = 0 c.tried = 0
} }
type RetryableError struct {
Err error
RetryAfter time.Duration
}
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
// and existing error.
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
retryAfter := header.Get("retry-after")
if retryAfter == "" {
retryAfter = "0"
}
res := &RetryableError{
Err: err,
}
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
}
return res
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
}

@ -1,7 +1,9 @@
package backoff package backoff
import ( import (
"errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http"
"testing" "testing"
"time" "time"
) )
@ -9,14 +11,51 @@ import (
func TestConstantBackoff_Next(t *testing.T) { func TestConstantBackoff_Next(t *testing.T) {
a := assert.New(t) a := assert.New(t)
// General error
{
err := errors.New("error")
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next()) a.True(b.Next(err))
a.True(b.Next()) a.True(b.Next(err))
a.True(b.Next()) a.True(b.Next(err))
a.False(b.Next()) a.False(b.Next(err))
b.Reset() b.Reset()
a.True(b.Next()) a.True(b.Next(err))
a.True(b.Next()) a.True(b.Next(err))
a.True(b.Next()) a.True(b.Next(err))
a.False(b.Next()) a.False(b.Next(err))
}
// Retryable error
{
err := &RetryableError{RetryAfter: time.Duration(1)}
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
}
func TestNewRetryableErrorFromHeader(t *testing.T) {
a := assert.New(t)
// no retry-after header
{
err := NewRetryableErrorFromHeader(nil, http.Header{})
a.Empty(err.RetryAfter)
}
// with retry-after header
{
header := http.Header{}
header.Add("retry-after", "120")
err := NewRetryableErrorFromHeader(nil, header)
a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter)
}
} }

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"io" "io"
"os" "os"
@ -66,7 +67,7 @@ func (c *ChunkGroup) TempAvailable() bool {
// Process a chunk with retry logic // Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
reader := io.LimitReader(c.file, int64(c.chunkSize)) reader := io.LimitReader(c.file, c.Length())
// If useBuffer is enabled, tee the reader to a temp file // If useBuffer is enabled, tee the reader to a temp file
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() { if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
@ -90,13 +91,17 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
} }
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name()) util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
reader = c.bufferTemp reader = io.NopCloser(c.bufferTemp)
} }
} }
err := processor(c, reader) err := processor(c, reader)
if err != nil { if err != nil {
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next() { if c.enableRetryBuffer {
request.BlackHole(reader)
}
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
if c.file.Seekable() { if c.file.Seekable() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil { if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err) return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)

@ -36,7 +36,7 @@ func TestHandler_Put(t *testing.T) {
{&fsctx.FileStream{ {&fsctx.FileStream{
SavePath: "TestHandler_Put.txt", SavePath: "TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")), File: io.NopCloser(strings.NewReader("")),
}, "物理同名文件已存在或不可用"}, }, "file with the same name existed or unavailable"},
{&fsctx.FileStream{ {&fsctx.FileStream{
SavePath: "inner/TestHandler_Put.txt", SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")), File: io.NopCloser(strings.NewReader("")),
@ -51,7 +51,7 @@ func TestHandler_Put(t *testing.T) {
Mode: fsctx.Append | fsctx.Overwrite, Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt", SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("123")), File: io.NopCloser(strings.NewReader("123")),
}, "未上传完成的文件分片与预期大小不一致"}, }, "size of unfinished uploaded chunks is not as expected"},
{&fsctx.FileStream{ {&fsctx.FileStream{
Mode: fsctx.Append | fsctx.Overwrite, Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt", SavePath: "inner/TestHandler_Put.txt",

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -37,24 +36,18 @@ const (
// GetSourcePath 获取文件的绝对路径 // GetSourcePath 获取文件的绝对路径
func (info *FileInfo) GetSourcePath() string { func (info *FileInfo) GetSourcePath() string {
res, err := url.PathUnescape( res, err := url.PathUnescape(info.ParentReference.Path)
strings.TrimPrefix( if err != nil {
return ""
}
return strings.TrimPrefix(
path.Join( path.Join(
strings.TrimPrefix(info.ParentReference.Path, "/drive/root:"), strings.TrimPrefix(res, "/drive/root:"),
info.Name, info.Name,
), ),
"/", "/",
),
) )
if err != nil {
return ""
}
return res
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
} }
func (client *Client) getRequestURL(api string, opts ...Option) string { func (client *Client) getRequestURL(api string, opts ...Option) string {
@ -531,7 +524,7 @@ func sysError(err error) *RespError {
}} }}
} }
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, *RespError) { func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
// 获取凭证 // 获取凭证
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave") err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
if err != nil { if err != nil {
@ -580,15 +573,21 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
util.Log().Debug("Onedrive returns unknown response: %s", respBody) util.Log().Debug("Onedrive returns unknown response: %s", respBody)
return "", sysError(decodeErr) return "", sysError(decodeErr)
} }
if res.Response.StatusCode == 429 {
util.Log().Warning("OneDrive request is throttled.")
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
}
return "", &errResp return "", &errResp
} }
return respBody, nil return respBody, nil
} }
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, *RespError) { func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
// 发送请求 // 发送请求
bodyReader := ioutil.NopCloser(strings.NewReader(body)) bodyReader := io.NopCloser(strings.NewReader(body))
return client.request(ctx, method, url, bodyReader, return client.request(ctx, method, url, bodyReader,
request.WithContentLength(int64(len(body))), request.WithContentLength(int64(len(body))),
) )

@ -112,6 +112,35 @@ func TestRequest(t *testing.T) {
asserts.Equal("error msg", err.Error()) asserts.Equal("error msg", err.Error())
} }
// OneDrive返回429错误
{
header := http.Header{}
header.Add("retry-after", "120")
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://dev.com",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 429,
Header: header,
Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)),
},
})
client.Request = clientMock
res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Empty(res)
var retryErr *backoff.RetryableError
asserts.ErrorAs(err, &retryErr)
asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter)
}
// OneDrive返回未知响应 // OneDrive返回未知响应
{ {
clientMock := ClientMock{} clientMock := ClientMock{}
@ -144,18 +173,18 @@ func TestFileInfo_GetSourcePath(t *testing.T) {
fileInfo := FileInfo{ fileInfo := FileInfo{
Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg", Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg",
ParentReference: parentReference{ ParentReference: parentReference{
Path: "/drive/root:/123/321", Path: "/drive/root:/123/32%201",
}, },
} }
asserts.Equal("123/321/文件名.jpg", fileInfo.GetSourcePath()) asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath())
} }
// 失败 // 失败
{ {
fileInfo := FileInfo{ fileInfo := FileInfo{
Name: "%e6%96%87%e4%bb%b6%e5%90%8g.jpg", Name: "123.jpg",
ParentReference: parentReference{ ParentReference: parentReference{
Path: "/drive/root:/123/321", Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g",
}, },
} }
asserts.Equal("", fileInfo.GetSourcePath()) asserts.Equal("", fileInfo.GetSourcePath())

@ -11,7 +11,6 @@ import (
"time" "time"
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
@ -171,19 +170,6 @@ func (handler Driver) Source(
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path) cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID) cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
// 如果是永久链接,则返回签名后的中转外链
if ttl == 0 {
signedURI, err := auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/source/%d/%s", file.ID, file.Name),
ttl,
)
if err != nil {
return "", err
}
return baseURL.ResolveReference(signedURI).String(), nil
}
} }
// 尝试从缓存中查找 // 尝试从缓存中查找

@ -3,7 +3,6 @@ package onedrive
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -161,21 +160,6 @@ func TestDriver_Source(t *testing.T) {
asserts.NoError(err) asserts.NoError(err)
asserts.Equal("123321", res) asserts.Equal("123321", res)
} }
// 成功 永久直链
{
file := model.File{}
file.ID = 1
file.Name = "123.jpg"
file.UpdatedAt = time.Now()
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
auth.General = auth.HMACAuth{}
handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
asserts.NoError(err)
asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign")
}
} }
func TestDriver_List(t *testing.T) { func TestDriver_List(t *testing.T) {

@ -133,3 +133,8 @@ type Site struct {
func init() { func init() {
gob.Register(Credential{}) gob.Register(Credential{})
} }
// Error implements the error interface by exposing the API error message.
func (e RespError) Error() string {
	return e.APIError.Message
}

@ -398,8 +398,8 @@ func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *seri
// Meta 获取文件信息 // Meta 获取文件信息
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) { func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.svc.GetObject( res, err := handler.svc.HeadObject(
&s3.GetObjectInput{ &s3.HeadObjectInput{
Bucket: &handler.Policy.BucketName, Bucket: &handler.Policy.BucketName,
Key: &path, Key: &path,
}) })

@ -8,17 +8,17 @@ import (
var ( var (
ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil) ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil)
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "", nil) ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil)
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "", nil) ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil)
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "", nil) ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil)
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "", nil) ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil)
ErrClientCanceled = errors.New("Client canceled operation") ErrClientCanceled = errors.New("Client canceled operation")
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "", nil) ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil)
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil) ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "", nil) ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil)
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "", nil) ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil)
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil) ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil)
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil) ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil)
ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil) ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil)
ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil) ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil)
ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil) ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil)

@ -472,6 +472,9 @@ func TestFileSystem_Delete(t *testing.T) {
AddRow(4, "1.txt", "1.txt", 365, 1), AddRow(4, "1.txt", "1.txt", 365, 1),
) )
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2))
// 两次查询软连接
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
// 查询上传策略 // 查询上传策略
@ -527,6 +530,9 @@ func TestFileSystem_Delete(t *testing.T) {
AddRow(4, "1.txt", "1.txt", 602, 1), AddRow(4, "1.txt", "1.txt", 602, 1),
) )
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2)) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2))
// 两次查询软连接
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
// 查询上传策略 // 查询上传策略

@ -15,11 +15,12 @@ const (
FolderID // 目录ID FolderID // 目录ID
TagID // 标签ID TagID // 标签ID
PolicyID // 存储策略ID PolicyID // 存储策略ID
SourceLinkID
) )
var ( var (
// ErrTypeNotMatch ID类型不匹配 // ErrTypeNotMatch ID类型不匹配
ErrTypeNotMatch = errors.New("ID类型不匹配") ErrTypeNotMatch = errors.New("mismatched ID type.")
) )
// HashEncode 对给定数据计算HashID // HashEncode 对给定数据计算HashID

@ -44,6 +44,12 @@ func newDefaultOption() *options {
} }
} }
func (o *options) clone() options {
newOptions := *o
newOptions.header = o.header.Clone()
return newOptions
}
// WithTimeout 设置请求超时 // WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option { func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) { return optionFunc(func(o *options) {

@ -56,7 +56,7 @@ func NewClient(opts ...Option) Client {
func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置 // 应用额外设置
c.mu.Lock() c.mu.Lock()
options := *c.options options := c.options.clone()
c.mu.Unlock() c.mu.Unlock()
for _, o := range opts { for _, o := range opts {
o.apply(&options) o.apply(&options)
@ -179,7 +179,7 @@ func (resp *Response) DecodeResponse() (*serializer.Response, error) {
var res serializer.Response var res serializer.Response
err = json.Unmarshal([]byte(respString), &res) err = json.Unmarshal([]byte(respString), &res)
if err != nil { if err != nil {
util.Log().Debug("无法解析回调服务端响应:%s", string(respString)) util.Log().Debug("Failed to parse response: %s", string(respString))
return nil, err return nil, err
} }
return &res, nil return &res, nil
@ -251,7 +251,7 @@ func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) {
return instance.status.Size, nil return instance.status.Size, nil
} }
} }
return 0, errors.New("未实现") return 0, errors.New("not implemented")
} }

@ -19,6 +19,7 @@ type DownloadListResponse struct {
Downloaded uint64 `json:"downloaded"` Downloaded uint64 `json:"downloaded"`
Speed int `json:"speed"` Speed int `json:"speed"`
Info rpc.StatusInfo `json:"info"` Info rpc.StatusInfo `json:"info"`
NodeName string `json:"node"`
} }
// FinishedListResponse 已完成任务条目 // FinishedListResponse 已完成任务条目
@ -34,6 +35,7 @@ type FinishedListResponse struct {
TaskError string `json:"task_error"` TaskError string `json:"task_error"`
CreateTime time.Time `json:"create"` CreateTime time.Time `json:"create"`
UpdateTime time.Time `json:"update"` UpdateTime time.Time `json:"update"`
NodeName string `json:"node"`
} }
// BuildFinishedListResponse 构建已完成任务条目 // BuildFinishedListResponse 构建已完成任务条目
@ -62,6 +64,7 @@ func BuildFinishedListResponse(tasks []model.Download) Response {
TaskStatus: -1, TaskStatus: -1,
UpdateTime: tasks[i].UpdatedAt, UpdateTime: tasks[i].UpdatedAt,
CreateTime: tasks[i].CreatedAt, CreateTime: tasks[i].CreatedAt,
NodeName: tasks[i].NodeName,
} }
if tasks[i].Task != nil { if tasks[i].Task != nil {
@ -106,6 +109,7 @@ func BuildDownloadingResponse(tasks []model.Download, intervals map[uint]int) Re
Downloaded: tasks[i].DownloadedSize, Downloaded: tasks[i].DownloadedSize,
Speed: tasks[i].Speed, Speed: tasks[i].Speed,
Info: tasks[i].StatusInfo, Info: tasks[i].StatusInfo,
NodeName: tasks[i].NodeName,
}) })
} }

@ -221,7 +221,7 @@ const (
// DBErr 数据库操作失败 // DBErr 数据库操作失败
func DBErr(msg string, err error) Response { func DBErr(msg string, err error) Response {
if msg == "" { if msg == "" {
msg = "数据库操作失败" msg = "Database operation failed."
} }
return Err(CodeDBError, msg, err) return Err(CodeDBError, msg, err)
} }
@ -229,7 +229,7 @@ func DBErr(msg string, err error) Response {
// ParamErr 各种参数错误 // ParamErr 各种参数错误
func ParamErr(msg string, err error) Response { func ParamErr(msg string, err error) Response {
if msg == "" { if msg == "" {
msg = "参数错误" msg = "Invalid parameters."
} }
return Err(CodeParamErr, msg, err) return Err(CodeParamErr, msg, err)
} }

@ -19,7 +19,7 @@ func NewResponseWithGobData(data interface{}) Response {
var w bytes.Buffer var w bytes.Buffer
encoder := gob.NewEncoder(&w) encoder := gob.NewEncoder(&w)
if err := encoder.Encode(data); err != nil { if err := encoder.Encode(data); err != nil {
return Err(CodeInternalSetting, "无法编码返回结果", err) return Err(CodeInternalSetting, "Failed to encode response content", err)
} }
return Response{Data: w.Bytes()} return Response{Data: w.Bytes()}

@ -22,6 +22,7 @@ type SiteConfig struct {
CaptchaType string `json:"captcha_type"` CaptchaType string `json:"captcha_type"`
TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"` TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"`
RegisterEnabled bool `json:"registerEnabled"` RegisterEnabled bool `json:"registerEnabled"`
AppPromotion bool `json:"app_promotion"`
} }
type task struct { type task struct {
@ -83,6 +84,7 @@ func BuildSiteConfig(settings map[string]string, user *model.User) Response {
CaptchaType: checkSettingValue(settings, "captcha_type"), CaptchaType: checkSettingValue(settings, "captcha_type"),
TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"), TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"),
RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")), RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")),
AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")),
}} }}
return res return res
} }

@ -19,14 +19,3 @@ func TestSlaveTransferReq_Hash(t *testing.T) {
} }
a.NotEqual(s1.Hash("1"), s2.Hash("1")) a.NotEqual(s1.Hash("1"), s2.Hash("1"))
} }
func TestSlaveRecycleReq_Hash(t *testing.T) {
a := assert.New(t)
s1 := &SlaveRecycleReq{
Path: "1",
}
s2 := &SlaveRecycleReq{
Path: "2",
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}

@ -13,7 +13,7 @@ import (
func CheckLogin() Response { func CheckLogin() Response {
return Response{ return Response{
Code: CodeCheckLogin, Code: CodeCheckLogin,
Msg: "未登录", Msg: "Login required",
} }
} }

@ -69,7 +69,7 @@ func (job *CompressTask) SetError(err *JobError) {
func (job *CompressTask) removeZipFile() { func (job *CompressTask) removeZipFile() {
if job.zipPath != "" { if job.zipPath != "" {
if err := os.Remove(job.zipPath); err != nil { if err := os.Remove(job.zipPath); err != nil {
util.Log().Warning("无法删除临时压缩文件 %s , %s", job.zipPath, err) util.Log().Warning("Failed to delete temp zip file %q: %s", job.zipPath, err)
} }
} }
} }
@ -93,7 +93,7 @@ func (job *CompressTask) Do() {
return return
} }
util.Log().Debug("开始压缩文件") util.Log().Debug("Starting compress file...")
job.TaskModel.SetProgress(CompressingProgress) job.TaskModel.SetProgress(CompressingProgress)
// 创建临时压缩文件 // 创建临时压缩文件
@ -122,7 +122,7 @@ func (job *CompressTask) Do() {
job.zipPath = zipFilePath job.zipPath = zipFilePath
zipFile.Close() zipFile.Close()
util.Log().Debug("压缩文件存放至%s开始上传", zipFilePath) util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFilePath)
job.TaskModel.SetProgress(TransferringProgress) job.TaskModel.SetProgress(TransferringProgress)
// 上传文件 // 上传文件

@ -77,7 +77,7 @@ func (job *DecompressTask) Do() {
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystem(job.User) fs, err := filesystem.NewFileSystem(job.User)
if err != nil { if err != nil {
job.SetErrorMsg("无法创建文件系统", err) job.SetErrorMsg("Failed to create filesystem.", err)
return return
} }
@ -85,7 +85,7 @@ func (job *DecompressTask) Do() {
err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst, job.TaskProps.Encoding) err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst, job.TaskProps.Encoding)
if err != nil { if err != nil {
job.SetErrorMsg("解压缩失败", err) job.SetErrorMsg("Failed to decompress file.", err)
return return
} }

@ -4,5 +4,5 @@ import "errors"
var ( var (
// ErrUnknownTaskType 未知任务类型 // ErrUnknownTaskType 未知任务类型
ErrUnknownTaskType = errors.New("未知任务类型") ErrUnknownTaskType = errors.New("unknown task type")
) )

@ -81,7 +81,7 @@ func (job *ImportTask) Do() {
// 查找存储策略 // 查找存储策略
policy, err := model.GetPolicyByID(job.TaskProps.PolicyID) policy, err := model.GetPolicyByID(job.TaskProps.PolicyID)
if err != nil { if err != nil {
job.SetErrorMsg("找不到存储策略", err) job.SetErrorMsg("Policy not exist.", err)
return return
} }
@ -96,7 +96,7 @@ func (job *ImportTask) Do() {
fs.Policy = &policy fs.Policy = &policy
if err := fs.DispatchHandler(); err != nil { if err := fs.DispatchHandler(); err != nil {
job.SetErrorMsg("无法分发存储策略", err) job.SetErrorMsg("Failed to dispatch policy.", err)
return return
} }
@ -110,7 +110,7 @@ func (job *ImportTask) Do() {
true) true)
objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive) objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive)
if err != nil { if err != nil {
job.SetErrorMsg("无法列取文件", err) job.SetErrorMsg("Failed to list files.", err)
return return
} }
@ -126,7 +126,7 @@ func (job *ImportTask) Do() {
virtualPath := path.Join(job.TaskProps.Dst, object.RelativePath) virtualPath := path.Join(job.TaskProps.Dst, object.RelativePath)
folder, err := fs.CreateDirectory(coxIgnoreConflict, virtualPath) folder, err := fs.CreateDirectory(coxIgnoreConflict, virtualPath)
if err != nil { if err != nil {
util.Log().Warning("导入任务无法创建用户目录[%s], %s", virtualPath, err) util.Log().Warning("Importing task cannot create user directory %q: %s", virtualPath, err)
} else if folder.ID > 0 { } else if folder.ID > 0 {
pathCache[virtualPath] = folder pathCache[virtualPath] = folder
} }
@ -152,7 +152,7 @@ func (job *ImportTask) Do() {
} else { } else {
folder, err := fs.CreateDirectory(context.Background(), virtualPath) folder, err := fs.CreateDirectory(context.Background(), virtualPath)
if err != nil { if err != nil {
util.Log().Warning("导入任务无法创建用户目录[%s], %s", util.Log().Warning("Importing task cannot create user directory %q: %s",
virtualPath, err) virtualPath, err)
continue continue
} }
@ -163,10 +163,10 @@ func (job *ImportTask) Do() {
// 插入文件记录 // 插入文件记录
_, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) _, err := fs.AddFile(context.Background(), parentFolder, &fileHeader)
if err != nil { if err != nil {
util.Log().Warning("导入任务无法创插入文件[%s], %s", util.Log().Warning("Importing task cannot insert user file %q: %s",
object.RelativePath, err) object.RelativePath, err)
if err == filesystem.ErrInsufficientCapacity { if err == filesystem.ErrInsufficientCapacity {
job.SetErrorMsg("容量不足", err) job.SetErrorMsg("Insufficient storage capacity.", err)
return return
} }
} }

@ -89,12 +89,12 @@ func Resume(p Pool) {
if len(tasks) == 0 { if len(tasks) == 0 {
return return
} }
util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks)) util.Log().Info("Resume %d unfinished task(s) from database.", len(tasks))
for i := 0; i < len(tasks); i++ { for i := 0; i < len(tasks); i++ {
job, err := GetJobFromModel(&tasks[i]) job, err := GetJobFromModel(&tasks[i])
if err != nil { if err != nil {
util.Log().Warning("无法恢复任务,%s", err) util.Log().Warning("Failed to resume task: %s", err)
continue continue
} }

@ -44,11 +44,11 @@ func (pool *AsyncPool) freeWorker() {
// Submit 开始提交任务 // Submit 开始提交任务
func (pool *AsyncPool) Submit(job Job) { func (pool *AsyncPool) Submit(job Job) {
go func() { go func() {
util.Log().Debug("等待获取Worker") util.Log().Debug("Waiting for Worker.")
worker := pool.obtainWorker() worker := pool.obtainWorker()
util.Log().Debug("获取到Worker") util.Log().Debug("Worker obtained.")
worker.Do(job) worker.Do(job)
util.Log().Debug("释放Worker") util.Log().Debug("Worker released.")
pool.freeWorker() pool.freeWorker()
}() }()
} }
@ -60,7 +60,7 @@ func Init() {
idleWorker: make(chan int, maxWorker), idleWorker: make(chan int, maxWorker),
} }
TaskPoll.Add(maxWorker) TaskPoll.Add(maxWorker)
util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker) util.Log().Info("Initialize task queue with WorkerNum = %d", maxWorker)
if conf.SystemConfig.Mode == "master" { if conf.SystemConfig.Mode == "master" {
Resume(TaskPoll) Resume(TaskPoll)

@ -73,21 +73,21 @@ func (job *RecycleTask) GetError() *JobError {
func (job *RecycleTask) Do() { func (job *RecycleTask) Do() {
download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID) download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID)
if err != nil { if err != nil {
util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID) util.Log().Warning("Recycle task %d cannot found download record.", job.TaskModel.ID)
job.SetErrorMsg("无法找到下载任务", err) job.SetErrorMsg("Cannot found download task.", err)
return return
} }
nodeID := download.GetNodeID() nodeID := download.GetNodeID()
node := cluster.Default.GetNodeByID(nodeID) node := cluster.Default.GetNodeByID(nodeID)
if node == nil { if node == nil {
util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID) util.Log().Warning("Recycle task %d cannot found node.", job.TaskModel.ID)
job.SetErrorMsg("从机节点不可用", nil) job.SetErrorMsg("Invalid slave node.", nil)
return return
} }
err = node.GetAria2Instance().DeleteTempFile(download) err = node.GetAria2Instance().DeleteTempFile(download)
if err != nil { if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err) util.Log().Warning("Failed to delete transfer temp folder %q: %s", download.Parent, err)
job.SetErrorMsg("文件回收失败", err) job.SetErrorMsg("Failed to recycle files.", err)
return return
} }
} }

@ -69,7 +69,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) {
} }
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
util.Log().Warning("无法发送转存失败通知到从机, %s", err) util.Log().Warning("Failed to send transfer failure notification to master node: %s", err)
} }
} }
@ -82,26 +82,26 @@ func (job *TransferTask) GetError() *task.JobError {
func (job *TransferTask) Do() { func (job *TransferTask) Do() {
fs, err := filesystem.NewAnonymousFileSystem() fs, err := filesystem.NewAnonymousFileSystem()
if err != nil { if err != nil {
job.SetErrorMsg("无法初始化匿名文件系统", err) job.SetErrorMsg("Failed to initialize anonymous filesystem.", err)
return return
} }
fs.Policy = job.Req.Policy fs.Policy = job.Req.Policy
if err := fs.DispatchHandler(); err != nil { if err := fs.DispatchHandler(); err != nil {
job.SetErrorMsg("无法分发存储策略", err) job.SetErrorMsg("Failed to dispatch policy.", err)
return return
} }
master, err := cluster.DefaultController.GetMasterInfo(job.MasterID) master, err := cluster.DefaultController.GetMasterInfo(job.MasterID)
if err != nil { if err != nil {
job.SetErrorMsg("找不到主机节点", err) job.SetErrorMsg("Cannot found master node ID.", err)
return return
} }
fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID)
file, err := os.Open(util.RelativePath(job.Req.Src)) file, err := os.Open(util.RelativePath(job.Req.Src))
if err != nil { if err != nil {
job.SetErrorMsg("无法读取源文件", err) job.SetErrorMsg("Failed to read source file.", err)
return return
} }
@ -110,7 +110,7 @@ func (job *TransferTask) Do() {
// 获取源文件大小 // 获取源文件大小
fi, err := file.Stat() fi, err := file.Stat()
if err != nil { if err != nil {
job.SetErrorMsg("无法获取源文件大小", err) job.SetErrorMsg("Failed to get source file size.", err)
return return
} }
@ -122,7 +122,7 @@ func (job *TransferTask) Do() {
Size: uint64(size), Size: uint64(size),
}) })
if err != nil { if err != nil {
job.SetErrorMsg("文件上传失败", err) job.SetErrorMsg("Upload failed.", err)
return return
} }
@ -133,6 +133,6 @@ func (job *TransferTask) Do() {
} }
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
util.Log().Warning("无法发送转存成功通知到从机, %s", err) util.Log().Warning("Failed to send transfer success notification to master node: %s", err)
} }
} }

@ -3,6 +3,7 @@ package task
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
@ -94,6 +95,7 @@ func (job *TransferTask) Do() {
} }
successCount := 0 successCount := 0
errorList := make([]string, 0, len(job.TaskProps.Src))
for _, file := range job.TaskProps.Src { for _, file := range job.TaskProps.Src {
dst := path.Join(job.TaskProps.Dst, filepath.Base(file)) dst := path.Join(job.TaskProps.Dst, filepath.Base(file))
if job.TaskProps.TrimPath { if job.TaskProps.TrimPath {
@ -109,7 +111,7 @@ func (job *TransferTask) Do() {
// 获取从机节点 // 获取从机节点
node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) node := cluster.Default.GetNodeByID(job.TaskProps.NodeID)
if node == nil { if node == nil {
job.SetErrorMsg("从机节点不可用", nil) job.SetErrorMsg("Invalid slave node.", nil)
} }
// 切换为从机节点处理上传 // 切换为从机节点处理上传
@ -127,13 +129,17 @@ func (job *TransferTask) Do() {
} }
if err != nil { if err != nil {
job.SetErrorMsg("文件转存失败", err) errorList = append(errorList, err.Error())
} else { } else {
successCount++ successCount++
job.TaskModel.SetProgress(successCount) job.TaskModel.SetProgress(successCount)
} }
} }
if len(errorList) > 0 {
job.SetErrorMsg("Failed to transfer one or more file(s).", fmt.Errorf(strings.Join(errorList, "\n")))
}
} }
// NewTransferTask 新建中转任务 // NewTransferTask 新建中转任务

@ -16,14 +16,14 @@ type GeneralWorker struct {
// Do 执行任务 // Do 执行任务
func (worker *GeneralWorker) Do(job Job) { func (worker *GeneralWorker) Do(job Job) {
util.Log().Debug("开始执行任务") util.Log().Debug("Start executing task.")
job.SetStatus(Processing) job.SetStatus(Processing)
defer func() { defer func() {
// 致命错误捕获 // 致命错误捕获
if err := recover(); err != nil { if err := recover(); err != nil {
util.Log().Debug("任务执行出错,%s", err) util.Log().Debug("Failed to execute task: %s", err)
job.SetError(&JobError{Msg: "致命错误", Error: fmt.Sprintf("%s", err)}) job.SetError(&JobError{Msg: "Fatal error.", Error: fmt.Sprintf("%s", err)})
job.SetStatus(Error) job.SetStatus(Error)
} }
}() }()
@ -33,12 +33,12 @@ func (worker *GeneralWorker) Do(job Job) {
// 任务执行失败 // 任务执行失败
if err := job.GetError(); err != nil { if err := job.GetError(); err != nil {
util.Log().Debug("任务执行出错") util.Log().Debug("Failed to execute task.")
job.SetStatus(Error) job.SetStatus(Error)
return return
} }
util.Log().Debug("任务执行完成") util.Log().Debug("Task finished.")
// 执行完成 // 执行完成
job.SetStatus(Complete) job.SetStatus(Complete)
} }

@ -45,7 +45,7 @@ func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) {
case "png": case "png":
img, err = png.Decode(file) img, err = png.Decode(file)
default: default:
return nil, errors.New("未知的图像类型") return nil, errors.New("unknown image format")
} }
if err != nil { if err != nil {
return nil, err return nil, err

@ -22,7 +22,7 @@ func CreatNestedFile(path string) (*os.File, error) {
if !Exists(basePath) { if !Exists(basePath) {
err := os.MkdirAll(basePath, 0700) err := os.MkdirAll(basePath, 0700)
if err != nil { if err != nil {
Log().Warning("无法创建目录,%s", err) Log().Warning("Failed to create directory: %s", err)
return nil, err return nil, err
} }
} }

@ -79,8 +79,8 @@ func AnonymousGetContent(c *gin.Context) {
} }
} }
// AnonymousPermLink 文件签名后的永久链接 // AnonymousPermLink Deprecated 文件签名后的永久链接
func AnonymousPermLink(c *gin.Context) { func AnonymousPermLinkDeprecated(c *gin.Context) {
// 创建上下文 // 创建上下文
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -102,6 +102,39 @@ func AnonymousPermLink(c *gin.Context) {
} }
} }
// AnonymousPermLink 文件中转后的永久直链接
func AnonymousPermLink(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sourceLinkRaw, ok := c.Get("source_link")
if !ok {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
return
}
sourceLink := sourceLinkRaw.(*model.SourceLink)
service := &explorer.FileAnonymousGetService{
ID: sourceLink.FileID,
Name: sourceLink.File.Name,
}
res := service.Source(ctx, c)
// 是否需要重定向
if res.Code == -302 {
c.Redirect(302, res.Data.(string))
return
}
// 是否有错误发生
if res.Code != 0 {
c.JSON(200, res)
}
}
func GetSource(c *gin.Context) { func GetSource(c *gin.Context) {
// 创建上下文 // 创建上下文
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

@ -27,6 +27,7 @@ func SiteConfig(c *gin.Context) {
"captcha_type", "captcha_type",
"captcha_TCaptcha_CaptchaAppId", "captcha_TCaptcha_CaptchaAppId",
"register_enabled", "register_enabled",
"show_app_promotion",
) )
// 如果已登录,则同时返回用户信息和标签 // 如果已登录,则同时返回用户信息和标签

@ -16,10 +16,10 @@ import (
// InitRouter 初始化路由 // InitRouter 初始化路由
func InitRouter() *gin.Engine { func InitRouter() *gin.Engine {
if conf.SystemConfig.Mode == "master" { if conf.SystemConfig.Mode == "master" {
util.Log().Info("当前运行模式Master") util.Log().Info("Current running mode: Master.")
return InitMasterRouter() return InitMasterRouter()
} }
util.Log().Info("当前运行模式Slave") util.Log().Info("Current running mode: Slave.")
return InitSlaveRouter() return InitSlaveRouter()
} }
@ -108,7 +108,7 @@ func InitCORS(router *gin.Engine) {
// slave模式下未启动跨域的警告 // slave模式下未启动跨域的警告
if conf.SystemConfig.Mode == "slave" { if conf.SystemConfig.Mode == "slave" {
util.Log().Warning("当前作为存储端Slave运行但未启用跨域配置可能会导致 Master 端无法正常上传文件") util.Log().Warning("You are running Cloudreve as slave node, if you are using slave storage policy, please enable CORS feature in config file, otherwise file cannot be uploaded from Master site.")
} }
} }
@ -145,6 +145,15 @@ func InitMasterRouter() *gin.Engine {
*/ */
{ {
// Redirect file source link
source := r.Group("f")
{
source.GET(":id/:name",
middleware.HashID(hashid.SourceLinkID),
middleware.ValidateSourceLink(),
controllers.AnonymousPermLink)
}
// 全局设置相关 // 全局设置相关
site := v3.Group("site") site := v3.Group("site")
{ {
@ -197,6 +206,7 @@ func InitMasterRouter() *gin.Engine {
// 获取用户头像 // 获取用户头像
user.GET("avatar/:id/:size", user.GET("avatar/:id/:size",
middleware.HashID(hashid.UserID), middleware.HashID(hashid.UserID),
middleware.StaticResourceCache(),
controllers.GetUserAvatar, controllers.GetUserAvatar,
) )
} }
@ -208,11 +218,18 @@ func InitMasterRouter() *gin.Engine {
file := sign.Group("file") file := sign.Group("file")
{ {
// 文件外链(直接输出文件数据) // 文件外链(直接输出文件数据)
file.GET("get/:id/:name", controllers.AnonymousGetContent) file.GET("get/:id/:name",
middleware.Sandbox(),
middleware.StaticResourceCache(),
controllers.AnonymousGetContent,
)
// 文件外链(301跳转) // 文件外链(301跳转)
file.GET("source/:id/:name", controllers.AnonymousPermLink) file.GET("source/:id/:name", controllers.AnonymousPermLinkDeprecated)
// 下载文件 // 下载文件
file.GET("download/:id", controllers.Download) file.GET("download/:id",
middleware.StaticResourceCache(),
controllers.Download,
)
// 打包并下载文件 // 打包并下载文件
file.GET("archive/:sessionID/archive.zip", controllers.DownloadArchive) file.GET("archive/:sessionID/archive.zip", controllers.DownloadArchive)
} }
@ -445,7 +462,7 @@ func InitMasterRouter() *gin.Engine {
// 列出文件 // 列出文件
file.POST("list", controllers.AdminListFile) file.POST("list", controllers.AdminListFile)
// 预览文件 // 预览文件
file.GET("preview/:id", controllers.AdminGetFile) file.GET("preview/:id", middleware.Sandbox(), controllers.AdminGetFile)
// 删除 // 删除
file.POST("delete", controllers.AdminDeleteFile) file.POST("delete", controllers.AdminDeleteFile)
// 列出用户或外部文件系统目录 // 列出用户或外部文件系统目录
@ -555,9 +572,9 @@ func InitMasterRouter() *gin.Engine {
// 创建文件下载会话 // 创建文件下载会话
file.PUT("download/:id", controllers.CreateDownloadSession) file.PUT("download/:id", controllers.CreateDownloadSession)
// 预览文件 // 预览文件
file.GET("preview/:id", controllers.Preview) file.GET("preview/:id", middleware.Sandbox(), controllers.Preview)
// 获取文本文件内容 // 获取文本文件内容
file.GET("content/:id", controllers.PreviewText) file.GET("content/:id", middleware.Sandbox(), controllers.PreviewText)
// 取得Office文档预览地址 // 取得Office文档预览地址
file.GET("doc/:id", controllers.GetDocPreview) file.GET("doc/:id", controllers.GetDocPreview)
// 获取缩略图 // 获取缩略图

@ -318,12 +318,20 @@ func (service *AdminListService) Policies() serializer.Response {
// 统计每个策略的文件使用 // 统计每个策略的文件使用
statics := make(map[uint][2]int, len(res)) statics := make(map[uint][2]int, len(res))
policyIds := make([]uint, 0, len(res))
for i := 0; i < len(res); i++ { for i := 0; i < len(res); i++ {
policyIds = append(policyIds, res[i].ID)
}
rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds).
Select("policy_id,count(id),sum(size)").Group("policy_id").Rows()
for rows.Next() {
policyId := uint(0)
total := [2]int{} total := [2]int{}
row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID). rows.Scan(&policyId, &total[0], &total[1])
Select("count(id),sum(size)").Row()
row.Scan(&total[0], &total[1]) statics[policyId] = total
statics[res[i].ID] = total
} }
return serializer.Response{Data: map[string]interface{}{ return serializer.Response{Data: map[string]interface{}{

@ -109,6 +109,7 @@ func (service *AddUserService) Add() serializer.Response {
user.Email = service.User.Email user.Email = service.User.Email
user.GroupID = service.User.GroupID user.GroupID = service.User.GroupID
user.Status = service.User.Status user.Status = service.User.Status
user.TwoFactor = service.User.TwoFactor
// 检查愚蠢操作 // 检查愚蠢操作
if user.ID == 1 && user.GroupID != 1 { if user.ID == 1 && user.GroupID != 1 {

@ -27,6 +27,13 @@ type DownloadListService struct {
func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录 // 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown) downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown)
for key, download := range downloads {
node := cluster.Default.GetNodeByID(download.GetNodeID())
if node != nil {
downloads[key].NodeName = node.DBModel().Name
}
}
return serializer.BuildFinishedListResponse(downloads) return serializer.BuildFinishedListResponse(downloads)
} }
@ -35,12 +42,17 @@ func (service *DownloadListService) Downloading(c *gin.Context, user *model.User
// 查找下载记录 // 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready)
intervals := make(map[uint]int) intervals := make(map[uint]int)
for _, download := range downloads { for key, download := range downloads {
if _, ok := intervals[download.ID]; !ok { if _, ok := intervals[download.ID]; !ok {
if node := cluster.Default.GetNodeByID(download.GetNodeID()); node != nil { if node := cluster.Default.GetNodeByID(download.GetNodeID()); node != nil {
intervals[download.ID] = node.DBModel().Aria2OptionsSerialized.Interval intervals[download.ID] = node.DBModel().Aria2OptionsSerialized.Interval
} }
} }
node := cluster.Default.GetNodeByID(download.GetNodeID())
if node != nil {
downloads[key].NodeName = node.DBModel().Name
}
} }
return serializer.BuildDownloadingResponse(downloads, intervals) return serializer.BuildDownloadingResponse(downloads, intervals)

@ -175,7 +175,7 @@ func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response
// SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容 // SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容
// See: https://github.com/OneDrive/onedrive-api-docs/issues/935 // See: https://github.com/OneDrive/onedrive-api-docs/issues/935
if strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) { if (strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") || strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.cn")) && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) {
isSizeCheckFailed = false isSizeCheckFailed = false
} }
@ -239,7 +239,7 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response {
return ProcessCallback(service, c) return ProcessCallback(service, c)
} }
// PreProcess 对OneDrive客户端回调进行预处理验证 // PreProcess 对从机客户端回调进行预处理验证
func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response { func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response {
// 创建文件系统 // 创建文件系统
fs, err := filesystem.NewFileSystemFromCallback(c) fs, err := filesystem.NewFileSystemFromCallback(c)

@ -178,12 +178,13 @@ func (service *FileAnonymousGetService) Source(ctx context.Context, c *gin.Conte
} }
// 获取文件流 // 获取文件流
res, err := fs.SignURL(ctx, &fs.FileTarget[0], ttl := int64(model.GetIntSetting("preview_timeout", 60))
int64(model.GetIntSetting("preview_timeout", 60)), false) res, err := fs.SignURL(ctx, &fs.FileTarget[0], ttl, false)
if err != nil { if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err) return serializer.Err(serializer.CodeNotSet, err.Error(), err)
} }
c.Header("Cache-Control", fmt.Sprintf("max-age=%d", ttl))
return serializer.Response{ return serializer.Response{
Code: -302, Code: -302,
Data: res, Data: res,
@ -442,14 +443,39 @@ func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer.
} }
res := make([]serializer.Sources, 0, len(s.Raw().Items)) res := make([]serializer.Sources, 0, len(s.Raw().Items))
for _, id := range s.Raw().Items { files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID)
fs.FileTarget = []model.File{} if err != nil || len(files) == 0 {
sourceURL, err := fs.GetSource(ctx, id) return serializer.Err(serializer.CodeFileNotFound, "", err)
if len(fs.FileTarget) > 0 { }
getSourceFunc := func(file model.File) (string, error) {
fs.FileTarget = []model.File{file}
return fs.GetSource(ctx, file.ID)
}
// Create redirected source link if needed
if fs.User.Group.OptionsSerialized.RedirectedSource {
getSourceFunc = func(file model.File) (string, error) {
source, err := file.CreateOrGetSourceLink()
if err != nil {
return "", err
}
sourceLinkURL, err := source.Link()
if err != nil {
return "", err
}
return sourceLinkURL, nil
}
}
for _, file := range files {
sourceURL, err := getSourceFunc(file)
current := serializer.Sources{ current := serializer.Sources{
URL: sourceURL, URL: sourceURL,
Name: fs.FileTarget[0].Name, Name: file.Name,
Parent: fs.FileTarget[0].FolderID, Parent: file.FolderID,
} }
if err != nil { if err != nil {
@ -458,7 +484,6 @@ func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer.
res = append(res, current) res = append(res, current)
} }
}
return serializer.Response{ return serializer.Response{
Code: 0, Code: 0,

Loading…
Cancel
Save