diff --git a/main.go b/main.go index 96f25c7..7c79a84 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,7 @@ import ( ) func init() { - conf.Init() + conf.Init("conf/conf.ini") model.Init() } diff --git a/models/init.go b/models/init.go index 25a5a6c..6968833 100644 --- a/models/init.go +++ b/models/init.go @@ -1,7 +1,9 @@ package model import ( + "Cloudreve/pkg/conf" "Cloudreve/pkg/util" + "fmt" "github.com/jinzhu/gorm" "time" @@ -15,13 +17,32 @@ var DB *gorm.DB func Init() { //TODO 从配置文件中读取 包括DEBUG模式 util.Log().Info("初始化数据库连接\n") - db, err := gorm.Open("mysql", "root:root@(localhost)/v3?charset=utf8&parseTime=True&loc=Local") + + var ( + db *gorm.DB + err error + ) + if conf.DatabaseConfig.Type == "UNSET" { + //TODO 使用内置SQLite数据库 + } else { + db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", + conf.DatabaseConfig.User, + conf.DatabaseConfig.Password, + conf.DatabaseConfig.Host, + conf.DatabaseConfig.Name)) + } + + // 处理表前缀 + gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string { + return conf.DatabaseConfig.TablePrefix + defaultTableName + } + db.LogMode(true) //db.SetLogger(util.Log()) - // Error if err != nil { util.Log().Panic("连接数据库不成功", err) } + //设置连接池 //空闲 db.DB().SetMaxIdleConns(50) @@ -32,5 +53,6 @@ func Init() { DB = db + //执行迁移 migration() } diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 957af07..fe051f3 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -2,15 +2,11 @@ package conf import ( "Cloudreve/pkg/util" - "fmt" "github.com/go-ini/ini" ) -type Conf struct { - Database Database -} - -type Database struct { +// Database 数据库 +type database struct { Type string User string Password string @@ -19,27 +15,33 @@ type Database struct { TablePrefix string } -var database = &Database{ +var DatabaseConfig = &database{ Type: "UNSET", } var cfg *ini.File -func Init() { +// Init 初始化配置文件 +func Init(path string) { var err error //TODO 配置文件不存在时创建 - cfg, err = ini.Load("conf/conf.ini") + //TODO 配置合法性验证 + cfg, err = ini.Load(path) + if err != nil { + util.Log().Panic("无法解析配置文件 '%s': ", path, err) + } + err = mapSection("Database", DatabaseConfig) if err != nil { - util.Log().Panic("无法解析配置文件 'conf/conf.ini': ", err) + util.Log().Warning("配置文件 %s 分区解析失败: ", "Database", err) } - mapSection("Database", database) - fmt.Println(database) } -func mapSection(section string, confStruct interface{}) { - err := cfg.Section("Database").MapTo(database) +// mapSection 将配置文件的 Section 映射到结构体上 +func mapSection(section string, confStruct interface{}) error { + err := cfg.Section(section).MapTo(confStruct) if err != nil { - util.Log().Warning("配置文件 Database 分区解析失败") + return err } + return nil } diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go new file mode 100644 index 0000000..dad078e --- /dev/null +++ b/pkg/conf/conf_test.go @@ -0,0 +1,78 @@ +package conf + +import ( + "github.com/stretchr/testify/assert" + "io/ioutil" + "testing" +) + +// 测试Init日志路径错误 +func TestInitPanic(t *testing.T) { + asserts := assert.New(t) + + // 日志路径不存在时 + asserts.Panics(func() { + Init("not/exist/path") + }) + +} + +// TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时 +func TestInitDelimiterNotFound(t *testing.T) { + asserts := assert.New(t) + testCase := `[Database] +Type = mysql +User = root +Password233root +Host = 127.0.0.1:3306 +Name = v3 +TablePrefix = v3_` + err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) + if err != nil { + panic(err) + } + asserts.Panics(func() { + Init("testConf.ini") + }) +} + +// TestInitNoPanic 日志路径存在且合法时 +func TestInitNoPanic(t *testing.T) { + asserts := assert.New(t) + testCase := `[Database] +Type = mysql +User = root +Password = root +Host = 127.0.0.1:3306 +Name = v3 +TablePrefix = v3_` + err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) + if err != nil { + panic(err) + } + asserts.NotPanics(func() { + Init("testConf.ini") + }) +} + +func TestMapSection(t *testing.T) { + asserts := assert.New(t) + + //正常情况 + testCase := `[Database] +Type = mysql +User = root +Password:root +Host = 127.0.0.1:3306 +Name = v3 +TablePrefix = v3_` + err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) + if err != nil { + panic(err) + } + Init("testConf.ini") + err = mapSection("Database", DatabaseConfig) + asserts.NoError(err) + + // TODO 类型不匹配测试 +} diff --git a/pkg/util/logger.go b/pkg/util/logger.go index 522a056..ee84633 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -2,7 +2,6 @@ package util import ( "fmt" - "os" "time" ) @@ -36,7 +35,7 @@ func (ll *Logger) Panic(format string, v ...interface{}) { } msg := fmt.Sprintf("[Panic] "+format, v...) ll.Println(msg) - os.Exit(0) + panic(msg) } // Error 错误 @@ -80,7 +79,7 @@ func (ll *Logger) Print(v ...interface{}) { if LevelDebug > ll.level { return } - msg := fmt.Sprintf("[SQL] ", v...) + msg := fmt.Sprintf("[SQL] %s", v...) ll.Println(msg) }