diff --git a/pkg/conf/defaults.go b/pkg/conf/defaults.go index eabaee5e..ace4f755 100644 --- a/pkg/conf/defaults.go +++ b/pkg/conf/defaults.go @@ -1,52 +1,60 @@ package conf +import "github.com/cloudreve/Cloudreve/v3/pkg/util" + // RedisConfig Redis服务器配置 var RedisConfig = &redis{ - Network: "tcp", - Server: "", - Password: "", - DB: "0", + Network: util.EnvStr("REDIS_NETWORK", "tcp"), + Server: util.EnvStr("REDIS_SERVER", ""), + Password: util.EnvStr("REDIS_PASSWORD", ""), + DB: util.EnvStr("REDIS_DB", "0"), } // DatabaseConfig 数据库配置 var DatabaseConfig = &database{ - Type: "UNSET", - Charset: "utf8", - DBFile: "cloudreve.db", - Port: 3306, + Type: util.EnvStr("DATABASE_TYPE", "UNSET"), + User: util.EnvStr("DATABASE_USER", "root"), + Password: util.EnvStr("DATABASE_PASSWORD", ""), + Host: util.EnvStr("DATABASE_HOST", "localhost"), + Name: util.EnvStr("DATABASE_NAME", "cloudreve"), + TablePrefix: util.EnvStr("DATABASE_TABLE_PREFIX", ""), + DBFile: util.EnvStr("DATABASE_DBFILE", "data.db"), + Port: util.EnvInt("DATABASE_PORT", 3306), + Charset: util.EnvStr("DATABASE_CHARSET", "utf8"), } // SystemConfig 系统公用配置 var SystemConfig = &system{ - Debug: false, - Mode: "master", - Listen: ":5212", + Debug: util.EnvStr("DEBUG", "false") == "true", + Mode: util.EnvStr("MODE", "master"), + Listen: util.EnvStr("LISTEN", ":5212"), } // CORSConfig 跨域配置 var CORSConfig = &cors{ - AllowOrigins: []string{"UNSET"}, - AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"}, - AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"}, - AllowCredentials: false, - ExposeHeaders: nil, + AllowOrigins: util.EnvArr("CORS_ALLOW_ORIGINS", []string{"UNSET"}), + AllowMethods: util.EnvArr("CORS_ALLOW_METHODS", []string{"PUT", "POST", "GET", "OPTIONS"}), + AllowHeaders: util.EnvArr("CORS_ALLOW_HEADERS", []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"}), + AllowCredentials: util.EnvStr("CORS_ALLOW_CREDENTIALS", "false") == "true", + ExposeHeaders: util.EnvArr("CORS_EXPOSE_HEADERS", nil), } // SlaveConfig 从机配置 var SlaveConfig = &slave{ - CallbackTimeout: 20, - SignatureTTL: 60, + Secret: util.EnvStr("SLAVE_SECRET", ""), + CallbackTimeout: util.EnvInt("SLAVE_CALLBACK_TIMEOUT", 10), + SignatureTTL: util.EnvInt("SLAVE_SIGNATURE_TTL", 10), } var SSLConfig = &ssl{ - Listen: ":443", - CertPath: "", - KeyPath: "", + Listen: util.EnvStr("SSL_LISTEN", ":443"), + CertPath: util.EnvStr("SSL_CERT_PATH", ""), + KeyPath: util.EnvStr("SSL_KEY_PATH", ""), } var UnixConfig = &unix{ - Listen: "", - ProxyHeader: "X-Forwarded-For", + Listen: util.EnvStr("UNIX_LISTEN", ""), + ProxyHeader: util.EnvStr("UNIX_PROXY_HEADER", "X-Forwarded-For"), } var OptionOverwrite = map[string]interface{}{} diff --git a/pkg/util/env.go b/pkg/util/env.go new file mode 100644 index 00000000..fd53dd71 --- /dev/null +++ b/pkg/util/env.go @@ -0,0 +1,40 @@ +package util + +import ( + "os" + "strconv" + "strings" +) + +// EnvStr returns the value of the environment variable named by the key. +func EnvStr(key, defaultValue string) string { + if value, exist := os.LookupEnv(key); exist { + return value + } + + return defaultValue +} + +// EnvInt returns the value of the environment variable named by the key. +func EnvInt(key string, defaultValue int) int { + if value, exist := os.LookupEnv(key); exist { + number, err := strconv.Atoi(value) + if err != nil { + // I think that we should log this error + return defaultValue + } + + return number + } + + return defaultValue +} + +// EnvArr returns the value of the environment variable named by the key. +func EnvArr(key string, defaultValue []string) []string { + if value, exist := os.LookupEnv(key); exist { + return strings.Split(value, ",") + } + + return defaultValue +} diff --git a/pkg/util/env_test.go b/pkg/util/env_test.go new file mode 100644 index 00000000..55a569b0 --- /dev/null +++ b/pkg/util/env_test.go @@ -0,0 +1,63 @@ +package util + +import ( + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func TestEnvStr(t *testing.T) { + asserts := assert.New(t) + + { + asserts.Equal("default", EnvStr("not_exist", "default")) + } + { + err := os.Setenv("exist", "value") + asserts.NoError(err) + asserts.Equal("value", EnvStr("exist", "default")) + + err = os.Unsetenv("exist") + asserts.NoError(err) + asserts.Equal("default", EnvStr("exist", "default")) + } +} + +func TestEnvInt(t *testing.T) { + asserts := assert.New(t) + + { + asserts.Equal(1, EnvInt("not_exist", 1)) + } + { + err := os.Setenv("exist", "2") + asserts.NoError(err) + asserts.Equal(2, EnvInt("exist", 1)) + + err = os.Unsetenv("exist") + asserts.NoError(err) + asserts.Equal(1, EnvInt("exist", 1)) + } + { + err := os.Setenv("exist", "not_number") + asserts.NoError(err) + asserts.Equal(1, EnvInt("exist", 1)) + } +} + +func TestEnvArr(t *testing.T) { + asserts := assert.New(t) + + { + asserts.Equal([]string{"default"}, EnvArr("not_exist", []string{"default"})) + } + { + err := os.Setenv("exist", "value1,value2") + asserts.NoError(err) + asserts.Equal([]string{"value1", "value2"}, EnvArr("exist", []string{"default"})) + + err = os.Unsetenv("exist") + asserts.NoError(err) + asserts.Equal([]string{"default"}, EnvArr("exist", []string{"default"})) + } +}